From e6a99034b9a0365b18c319fb7380124b7f30f7a6 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Fri, 13 Mar 2026 21:20:46 -0500 Subject: [PATCH 01/82] feat: incremental loading --- cmd/wire/check_cmd.go | 7 +- cmd/wire/diff_cmd.go | 9 +- cmd/wire/gen_cmd.go | 11 +- cmd/wire/incremental_flag.go | 60 ++ cmd/wire/main.go | 42 +- cmd/wire/show_cmd.go | 7 +- cmd/wire/watch_cmd.go | 11 +- go.mod | 2 +- internal/wire/cache_bypass.go | 17 + internal/wire/cache_test.go | 6 +- internal/wire/generate_package.go | 2 +- internal/wire/incremental.go | 65 ++ internal/wire/incremental_bench_test.go | 654 +++++++++++++ internal/wire/incremental_fingerprint.go | 421 ++++++++ internal/wire/incremental_fingerprint_test.go | 104 ++ internal/wire/incremental_graph.go | 306 ++++++ internal/wire/incremental_graph_test.go | 97 ++ internal/wire/incremental_manifest.go | 876 +++++++++++++++++ internal/wire/incremental_session.go | 95 ++ internal/wire/incremental_summary.go | 647 ++++++++++++ internal/wire/incremental_summary_test.go | 287 ++++++ internal/wire/incremental_test.go | 65 ++ internal/wire/load_debug.go | 304 ++++++ internal/wire/loader_test.go | 920 ++++++++++++++++++ internal/wire/local_fastpath.go | 556 +++++++++++ internal/wire/parse.go | 26 +- internal/wire/parser_lazy_loader.go | 64 +- internal/wire/parser_lazy_loader_test.go | 55 +- internal/wire/time_compat.go | 22 + internal/wire/timing.go | 8 + internal/wire/wire.go | 89 +- internal/wire/wire_test.go | 50 + 32 files changed, 5843 insertions(+), 42 deletions(-) create mode 100644 cmd/wire/incremental_flag.go create mode 100644 internal/wire/cache_bypass.go create mode 100644 internal/wire/incremental.go create mode 100644 internal/wire/incremental_bench_test.go create mode 100644 internal/wire/incremental_fingerprint.go create mode 100644 internal/wire/incremental_fingerprint_test.go create mode 100644 internal/wire/incremental_graph.go create mode 100644 internal/wire/incremental_graph_test.go create mode 100644 
internal/wire/incremental_manifest.go create mode 100644 internal/wire/incremental_session.go create mode 100644 internal/wire/incremental_summary.go create mode 100644 internal/wire/incremental_summary_test.go create mode 100644 internal/wire/incremental_test.go create mode 100644 internal/wire/load_debug.go create mode 100644 internal/wire/local_fastpath.go create mode 100644 internal/wire/time_compat.go diff --git a/cmd/wire/check_cmd.go b/cmd/wire/check_cmd.go index 7857437..71872d9 100644 --- a/cmd/wire/check_cmd.go +++ b/cmd/wire/check_cmd.go @@ -26,8 +26,9 @@ import ( ) type checkCmd struct { - tags string - profile profileFlags + tags string + incremental optionalBoolFlag + profile profileFlags } // Name returns the subcommand name. @@ -52,6 +53,7 @@ func (*checkCmd) Usage() string { // SetFlags registers flags for the subcommand. func (cmd *checkCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") + addIncrementalFlag(&cmd.incremental, f) cmd.profile.addFlags(f) } @@ -65,6 +67,7 @@ func (cmd *checkCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter defer stop() totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) + ctx = cmd.incremental.apply(ctx) wd, err := os.Getwd() if err != nil { diff --git a/cmd/wire/diff_cmd.go b/cmd/wire/diff_cmd.go index 592cced..c7facca 100644 --- a/cmd/wire/diff_cmd.go +++ b/cmd/wire/diff_cmd.go @@ -29,9 +29,10 @@ import ( ) type diffCmd struct { - headerFile string - tags string - profile profileFlags + headerFile string + tags string + incremental optionalBoolFlag + profile profileFlags } // Name returns the subcommand name. 
@@ -60,6 +61,7 @@ func (*diffCmd) Usage() string { func (cmd *diffCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") + addIncrementalFlag(&cmd.incremental, f) cmd.profile.addFlags(f) } @@ -77,6 +79,7 @@ func (cmd *diffCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interf defer stop() totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) + ctx = cmd.incremental.apply(ctx) wd, err := os.Getwd() if err != nil { diff --git a/cmd/wire/gen_cmd.go b/cmd/wire/gen_cmd.go index 1532dd4..13b88ed 100644 --- a/cmd/wire/gen_cmd.go +++ b/cmd/wire/gen_cmd.go @@ -29,6 +29,7 @@ type genCmd struct { headerFile string prefixFileName string tags string + incremental optionalBoolFlag profile profileFlags } @@ -55,6 +56,7 @@ func (cmd *genCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") f.StringVar(&cmd.prefixFileName, "output_file_prefix", "", "string to prepend to output file names.") f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") + addIncrementalFlag(&cmd.incremental, f) cmd.profile.addFlags(f) } @@ -68,6 +70,7 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa defer stop() totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) + ctx = cmd.incremental.apply(ctx) wd, err := os.Getwd() if err != nil { @@ -107,8 +110,12 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa // No Wire output. Maybe errors, maybe no Wire directives. 
continue } - if err := out.Commit(); err == nil { - log.Printf("%s: wrote %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + if wrote, err := out.CommitWithStatus(); err == nil { + if wrote { + log.Printf("%s: wrote %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + } else { + log.Printf("%s: unchanged %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + } } else { log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) success = false diff --git a/cmd/wire/incremental_flag.go b/cmd/wire/incremental_flag.go new file mode 100644 index 0000000..2962128 --- /dev/null +++ b/cmd/wire/incremental_flag.go @@ -0,0 +1,60 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package main + +import ( + "context" + "flag" + "strconv" + + "github.com/goforj/wire/internal/wire" +) + +type optionalBoolFlag struct { + value bool + set bool +} + +func (f *optionalBoolFlag) String() string { + if f == nil { + return "" + } + return strconv.FormatBool(f.value) +} + +func (f *optionalBoolFlag) Set(s string) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + f.value = v + f.set = true + return nil +} + +func (f *optionalBoolFlag) IsBoolFlag() bool { + return true +} + +func (f *optionalBoolFlag) apply(ctx context.Context) context.Context { + if f == nil || !f.set { + return ctx + } + return wire.WithIncremental(ctx, f.value) +} + +func addIncrementalFlag(f *optionalBoolFlag, fs *flag.FlagSet) { + fs.Var(f, "incremental", "enable the incremental engine (overrides "+wire.IncrementalEnvVar+")") +} diff --git a/cmd/wire/main.go b/cmd/wire/main.go index 2f90783..3166531 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -34,6 +34,13 @@ import ( "github.com/google/subcommands" ) +var topLevelIncremental optionalBoolFlag + +const ( + ansiRed = "\033[31m" + ansiReset = "\033[0m" +) + // main wires up subcommands and executes the selected command. func main() { subcommands.Register(subcommands.CommandsCommand(), "") @@ -45,6 +52,7 @@ func main() { subcommands.Register(&genCmd{}, "") subcommands.Register(&watchCmd{}, "") subcommands.Register(&showCmd{}, "") + addIncrementalFlag(&topLevelIncremental, flag.CommandLine) flag.Parse() // Initialize the default logger to log to stderr. @@ -71,9 +79,9 @@ func main() { // Default to running the "gen" command. 
if args := flag.Args(); len(args) == 0 || !allCmds[args[0]] { genCmd := &genCmd{} - os.Exit(int(genCmd.Execute(context.Background(), flag.CommandLine))) + os.Exit(int(genCmd.Execute(topLevelIncremental.apply(context.Background()), flag.CommandLine))) } - os.Exit(int(subcommands.Execute(context.Background()))) + os.Exit(int(subcommands.Execute(topLevelIncremental.apply(context.Background())))) } // installStackDumper registers signal handlers to dump goroutine stacks. @@ -200,6 +208,34 @@ func newGenerateOptions(headerFile string) (*wire.GenerateOptions, error) { // logErrors logs each error with consistent formatting. func logErrors(errs []error) { for _, err := range errs { - log.Println(strings.Replace(err.Error(), "\n", "\n\t", -1)) + msg := err.Error() + if strings.Contains(msg, "\n") { + logMultilineError("\n " + strings.ReplaceAll(msg, "\n", "\n ")) + continue + } + logMultilineError(msg) + } +} + +func logMultilineError(msg string) { + if shouldColorStderr() { + log.Print(ansiRed + msg + ansiReset) + return + } + log.Print(msg) +} + +func shouldColorStderr() bool { + if os.Getenv("NO_COLOR") != "" { + return false + } + term := os.Getenv("TERM") + if term == "" || term == "dumb" { + return false + } + info, err := os.Stderr.Stat() + if err != nil { + return false } + return (info.Mode() & os.ModeCharDevice) != 0 } diff --git a/cmd/wire/show_cmd.go b/cmd/wire/show_cmd.go index 5a81b29..1313ade 100644 --- a/cmd/wire/show_cmd.go +++ b/cmd/wire/show_cmd.go @@ -34,8 +34,9 @@ import ( ) type showCmd struct { - tags string - profile profileFlags + tags string + incremental optionalBoolFlag + profile profileFlags } // Name returns the subcommand name. @@ -62,6 +63,7 @@ func (*showCmd) Usage() string { // SetFlags registers flags for the subcommand. 
func (cmd *showCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") + addIncrementalFlag(&cmd.incremental, f) cmd.profile.addFlags(f) } @@ -75,6 +77,7 @@ func (cmd *showCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interf defer stop() totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) + ctx = cmd.incremental.apply(ctx) wd, err := os.Getwd() if err != nil { diff --git a/cmd/wire/watch_cmd.go b/cmd/wire/watch_cmd.go index 779625f..13743cd 100644 --- a/cmd/wire/watch_cmd.go +++ b/cmd/wire/watch_cmd.go @@ -36,6 +36,7 @@ type watchCmd struct { headerFile string prefixFileName string tags string + incremental optionalBoolFlag profile profileFlags pollInterval time.Duration rescanInterval time.Duration @@ -63,6 +64,7 @@ func (cmd *watchCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") f.StringVar(&cmd.prefixFileName, "output_file_prefix", "", "string to prepend to output file names.") f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") + addIncrementalFlag(&cmd.incremental, f) f.DurationVar(&cmd.pollInterval, "poll_interval", 250*time.Millisecond, "interval between file stat checks") f.DurationVar(&cmd.rescanInterval, "rescan_interval", 2*time.Second, "interval to rescan for new or removed Go files") cmd.profile.addFlags(f) @@ -77,6 +79,7 @@ func (cmd *watchCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter } defer stop() ctx = withTiming(ctx, cmd.profile.timings) + ctx = cmd.incremental.apply(ctx) if cmd.pollInterval <= 0 { log.Println("poll_interval must be greater than zero") @@ -126,8 +129,12 @@ func (cmd *watchCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter if len(out.Content) == 0 { continue } - if err := out.Commit(); err == nil { - log.Printf("%s: wrote %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) 
+ if wrote, err := out.CommitWithStatus(); err == nil { + if wrote { + log.Printf("%s: wrote %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + } else { + log.Printf("%s: unchanged %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + } } else { log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) success = false diff --git a/go.mod b/go.mod index e800555..5db8855 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/goforj/wire go 1.19 require ( + github.com/fsnotify/fsnotify v1.7.0 github.com/google/go-cmp v0.6.0 github.com/google/subcommands v1.2.0 github.com/pmezard/go-difflib v1.0.0 @@ -10,7 +11,6 @@ require ( ) require ( - github.com/fsnotify/fsnotify v1.7.0 // indirect golang.org/x/mod v0.20.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.23.0 // indirect diff --git a/internal/wire/cache_bypass.go b/internal/wire/cache_bypass.go new file mode 100644 index 0000000..b195eef --- /dev/null +++ b/internal/wire/cache_bypass.go @@ -0,0 +1,17 @@ +package wire + +import "context" + +type bypassPackageCacheKey struct{} + +func withBypassPackageCache(ctx context.Context) context.Context { + return context.WithValue(ctx, bypassPackageCacheKey{}, true) +} + +func bypassPackageCache(ctx context.Context) bool { + if ctx == nil { + return false + } + v, _ := ctx.Value(bypassPackageCacheKey{}).(bool) + return v +} diff --git a/internal/wire/cache_test.go b/internal/wire/cache_test.go index bc55bae..6ffb20a 100644 --- a/internal/wire/cache_test.go +++ b/internal/wire/cache_test.go @@ -123,8 +123,10 @@ func TestCacheInvalidation(t *testing.T) { if key2 == key { t.Fatal("expected cache key to change after source update") } - if cached, ok := readCache(key2); !ok || len(cached) == 0 { - t.Fatal("expected cache entry after second Generate") + if !IncrementalEnabled(ctx, env) { + if cached, ok := readCache(key2); !ok || len(cached) == 0 { + t.Fatal("expected 
cache entry after second Generate") + } } } diff --git a/internal/wire/generate_package.go b/internal/wire/generate_package.go index de34aa6..01d3d20 100644 --- a/internal/wire/generate_package.go +++ b/internal/wire/generate_package.go @@ -47,7 +47,7 @@ func generateForPackage(ctx context.Context, pkg *packages.Package, loader *lazy res.Errs = append(res.Errs, err) return res } - if cacheKey != "" { + if cacheKey != "" && !bypassPackageCache(ctx) { cacheHitStart := time.Now() if cached, ok := readCache(cacheKey); ok { res.Content = cached diff --git a/internal/wire/incremental.go b/internal/wire/incremental.go new file mode 100644 index 0000000..0bc334c --- /dev/null +++ b/internal/wire/incremental.go @@ -0,0 +1,65 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "context" + "strconv" + "strings" +) + +const IncrementalEnvVar = "WIRE_INCREMENTAL" + +type incrementalKey struct{} + +// WithIncremental overrides incremental-mode resolution for the provided +// context. This takes precedence over the environment variable. +func WithIncremental(ctx context.Context, enabled bool) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, incrementalKey{}, enabled) +} + +// IncrementalEnabled reports whether incremental mode is enabled for the +// current operation. A context override takes precedence over env. 
+func IncrementalEnabled(ctx context.Context, env []string) bool { + if ctx != nil { + if v := ctx.Value(incrementalKey{}); v != nil { + if enabled, ok := v.(bool); ok { + return enabled + } + } + } + raw, ok := lookupEnv(env, IncrementalEnvVar) + if !ok { + return false + } + enabled, err := strconv.ParseBool(strings.TrimSpace(raw)) + if err != nil { + return false + } + return enabled +} + +func lookupEnv(env []string, key string) (string, bool) { + prefix := key + "=" + for i := len(env) - 1; i >= 0; i-- { + if strings.HasPrefix(env[i], prefix) { + return strings.TrimPrefix(env[i], prefix), true + } + } + return "", false +} diff --git a/internal/wire/incremental_bench_test.go b/internal/wire/incremental_bench_test.go new file mode 100644 index 0000000..b981d23 --- /dev/null +++ b/internal/wire/incremental_bench_test.go @@ -0,0 +1,654 @@ +package wire + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "text/tabwriter" + "testing" + "time" +) + +const ( + largeBenchmarkTestPackageCount = 24 + largeBenchmarkHelperCount = 12 +) + +var largeBenchmarkSizes = []int{10, 100, 1000} + +func BenchmarkGenerateIncrementalFirstSeenShapeChange(b *testing.B) { + cacheHooksMu.Lock() + state := saveCacheHooks() + b.Cleanup(func() { + restoreCacheHooks(state) + cacheHooksMu.Unlock() + }) + + repoRoot := benchmarkRepoRoot(b) + + for i := 0; i < b.N; i++ { + cacheRoot := b.TempDir() + osTempDir = func() string { return cacheRoot } + + root := b.TempDir() + writeIncrementalBenchmarkModule(b, repoRoot, root) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + b.Fatalf("baseline Generate returned errors: %v", errs) + } + + writeBenchmarkFile(b, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string; Count int }", + "", + "func NewMessage() string { return 
\"ok\" }", + "", + "func NewCount() int { return 7 }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: msg, Count: count}", + "}", + "", + }, "\n")) + writeBenchmarkFile(b, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + + b.StartTimer() + gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + b.StopTimer() + if len(errs) > 0 { + b.Fatalf("incremental shape-change Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + b.Fatalf("unexpected Generate results: %+v", gens) + } + } +} + +func BenchmarkGenerateLargeRepoNormalShapeChange(b *testing.B) { + runLargeRepoShapeChangeBenchmarks(b, false) +} + +func BenchmarkGenerateLargeRepoIncrementalShapeChange(b *testing.B) { + runLargeRepoShapeChangeBenchmarks(b, true) +} + +func TestPrintLargeRepoBenchmarkComparisonTable(t *testing.T) { + if os.Getenv("WIRE_BENCH_TABLE") == "" { + t.Skip("set WIRE_BENCH_TABLE=1 to print the large-repo benchmark comparison table") + } + + cacheHooksMu.Lock() + state := saveCacheHooks() + t.Cleanup(func() { + restoreCacheHooks(state) + cacheHooksMu.Unlock() + }) + + repoRoot := benchmarkRepoRoot(t) + rows := make([]largeRepoBenchmarkRow, 0, len(largeBenchmarkSizes)) + for _, packageCount := range largeBenchmarkSizes { + coldNormal := measureLargeRepoColdOnce(t, repoRoot, packageCount, false) + coldIncremental := measureLargeRepoColdOnce(t, repoRoot, packageCount, true) + normal := measureLargeRepoShapeChangeOnce(t, repoRoot, packageCount, false) + incremental := measureLargeRepoShapeChangeOnce(t, repoRoot, packageCount, true) + knownToggle := measureLargeRepoKnownToggleOnce(t, repoRoot, packageCount) + rows = append(rows, largeRepoBenchmarkRow{ + packageCount: packageCount, + coldNormal: coldNormal, + coldIncremental: coldIncremental, + normal: 
normal, + incremental: incremental, + knownToggle: knownToggle, + }) + } + + var out strings.Builder + tw := tabwriter.NewWriter(&out, 0, 0, 2, ' ', 0) + fmt.Fprintln(tw, "size\tcold normal\tcold incr\tcold delta\tcold x\tshape normal\tshape incr\tshape delta\tshape x\tknown toggle") + for _, row := range rows { + fmt.Fprintf(tw, "%d\t%s\t%s\t%s\t%.2fx\t%s\t%s\t%s\t%.2fx\t%s\n", + row.packageCount, + formatBenchmarkDuration(row.coldNormal), + formatBenchmarkDuration(row.coldIncremental), + formatPercentImprovement(row.coldNormal, row.coldIncremental), + speedupRatio(row.coldNormal, row.coldIncremental), + formatBenchmarkDuration(row.normal), + formatBenchmarkDuration(row.incremental), + formatPercentImprovement(row.normal, row.incremental), + speedupRatio(row.normal, row.incremental), + formatBenchmarkDuration(row.knownToggle), + ) + } + if err := tw.Flush(); err != nil { + t.Fatalf("flush benchmark table: %v", err) + } + fmt.Print(out.String()) +} + +func TestPrintLargeRepoShapeChangeBreakdownTable(t *testing.T) { + if os.Getenv("WIRE_BENCH_BREAKDOWN") == "" { + t.Skip("set WIRE_BENCH_BREAKDOWN=1 to print the large-repo shape-change breakdown table") + } + + cacheHooksMu.Lock() + state := saveCacheHooks() + t.Cleanup(func() { + restoreCacheHooks(state) + cacheHooksMu.Unlock() + }) + + repoRoot := benchmarkRepoRoot(t) + var out strings.Builder + tw := tabwriter.NewWriter(&out, 0, 0, 2, ' ', 0) + fmt.Fprintln(tw, "size\tnormal total\tbase load\tlazy load\tincr total\tfast load\tfast generate\tspeedup") + for _, packageCount := range largeBenchmarkSizes { + normal := measureLargeRepoShapeChangeTraceOnce(t, repoRoot, packageCount, false) + incremental := measureLargeRepoShapeChangeTraceOnce(t, repoRoot, packageCount, true) + fmt.Fprintf(tw, "%d\t%s\t%s\t%s\t%s\t%s\t%s\t%.2fx\n", + packageCount, + formatBenchmarkDuration(normal.total), + formatBenchmarkDuration(normal.label("load.packages.base.load")), + formatBenchmarkDuration(normal.label("load.packages.lazy.load")), 
+ formatBenchmarkDuration(incremental.total), + formatBenchmarkDuration(incremental.label("incremental.local_fastpath.load")), + formatBenchmarkDuration(incremental.label("incremental.local_fastpath.generate")), + speedupRatio(normal.total, incremental.total), + ) + } + if err := tw.Flush(); err != nil { + t.Fatalf("flush breakdown table: %v", err) + } + fmt.Print(out.String()) +} + +func writeIncrementalBenchmarkModule(tb testing.TB, repoRoot string, root string) { + tb.Helper() + + writeBenchmarkFile(tb, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeBenchmarkFile(tb, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + writeBenchmarkFile(tb, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) +} + +func TestGenerateIncrementalLargeRepoShapeChangeMatchesNormalGenerate(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := benchmarkRepoRoot(t) + root := t.TempDir() + writeLargeBenchmarkModule(t, repoRoot, root, 
largeBenchmarkTestPackageCount) + + env := append(os.Environ(), "GOWORK=off") + incrementalCtx := WithIncremental(context.Background(), true) + + if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + t.Fatalf("baseline incremental Generate returned errors: %v", errs) + } + + mutateLargeBenchmarkModule(t, root, largeBenchmarkTestPackageCount/2) + + var incrementalLabels []string + incrementalTimedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { + incrementalLabels = append(incrementalLabels, label) + }) + incrementalGens, errs := Generate(incrementalTimedCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("incremental large-repo Generate returned errors: %v", errs) + } + if len(incrementalGens) != 1 || len(incrementalGens[0].Errs) > 0 { + t.Fatalf("unexpected incremental results: %+v", incrementalGens) + } + if !containsLabel(incrementalLabels, "incremental.local_fastpath.load") { + t.Fatalf("expected large-repo shape change to use local fast path, labels=%v", incrementalLabels) + } + + normalGens, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("normal large-repo Generate returned errors: %v", errs) + } + if len(normalGens) != 1 || len(normalGens[0].Errs) > 0 { + t.Fatalf("unexpected normal results: %+v", normalGens) + } + if incrementalGens[0].OutputPath != normalGens[0].OutputPath { + t.Fatalf("output paths differ: incremental=%q normal=%q", incrementalGens[0].OutputPath, normalGens[0].OutputPath) + } + if string(incrementalGens[0].Content) != string(normalGens[0].Content) { + t.Fatal("large-repo shape-changing incremental output differs from normal Generate output") + } +} + +func runLargeRepoShapeChangeBenchmarks(b *testing.B, incremental bool) { + cacheHooksMu.Lock() + state := saveCacheHooks() + b.Cleanup(func() { + restoreCacheHooks(state) + cacheHooksMu.Unlock() + }) + + 
repoRoot := benchmarkRepoRoot(b) + for _, packageCount := range largeBenchmarkSizes { + packageCount := packageCount + b.Run(fmt.Sprintf("size=%d", packageCount), func(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StartTimer() + _ = measureLargeRepoShapeChangeOnce(b, repoRoot, packageCount, incremental) + b.StopTimer() + } + }) + } +} + +type largeRepoBenchmarkRow struct { + packageCount int + coldNormal time.Duration + coldIncremental time.Duration + normal time.Duration + incremental time.Duration + knownToggle time.Duration +} + +type shapeChangeTrace struct { + total time.Duration + labels map[string]time.Duration +} + +func measureLargeRepoShapeChangeOnce(tb testing.TB, repoRoot string, packageCount int, incremental bool) time.Duration { + tb.Helper() + + cacheRoot := tb.TempDir() + osTempDir = func() string { return cacheRoot } + + root := tb.TempDir() + writeLargeBenchmarkModule(tb, repoRoot, root, packageCount) + env := append(os.Environ(), "GOWORK=off") + ctx := context.Background() + if incremental { + ctx = WithIncremental(ctx, true) + } + if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + tb.Fatalf("baseline Generate returned errors: %v", errs) + } + + mutateLargeBenchmarkModule(tb, root, packageCount/2) + + start := time.Now() + gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + dur := time.Since(start) + if len(errs) > 0 { + tb.Fatalf("shape-change Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + tb.Fatalf("unexpected Generate results: %+v", gens) + } + return dur +} + +func measureLargeRepoShapeChangeTraceOnce(tb testing.TB, repoRoot string, packageCount int, incremental bool) shapeChangeTrace { + tb.Helper() + + cacheRoot := tb.TempDir() + osTempDir = func() string { return cacheRoot } + + root := tb.TempDir() + writeLargeBenchmarkModule(tb, repoRoot, root, packageCount) + env := append(os.Environ(), "GOWORK=off") + ctx := 
context.Background() + if incremental { + ctx = WithIncremental(ctx, true) + } + if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + tb.Fatalf("baseline Generate returned errors: %v", errs) + } + + mutateLargeBenchmarkModule(tb, root, packageCount/2) + + trace := shapeChangeTrace{labels: make(map[string]time.Duration)} + ctx = WithTiming(ctx, func(label string, dur time.Duration) { + trace.labels[label] += dur + }) + start := time.Now() + gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + trace.total = time.Since(start) + if len(errs) > 0 { + tb.Fatalf("shape-change Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + tb.Fatalf("unexpected Generate results: %+v", gens) + } + return trace +} + +func (s shapeChangeTrace) label(name string) time.Duration { + if s.labels == nil { + return 0 + } + return s.labels[name] +} + +func measureLargeRepoColdOnce(tb testing.TB, repoRoot string, packageCount int, incremental bool) time.Duration { + tb.Helper() + + cacheRoot := tb.TempDir() + osTempDir = func() string { return cacheRoot } + + root := tb.TempDir() + writeLargeBenchmarkModule(tb, repoRoot, root, packageCount) + env := append(os.Environ(), "GOWORK=off") + ctx := context.Background() + if incremental { + ctx = WithIncremental(ctx, true) + } + + start := time.Now() + gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + dur := time.Since(start) + if len(errs) > 0 { + tb.Fatalf("cold Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + tb.Fatalf("unexpected cold Generate results: %+v", gens) + } + return dur +} + +func measureLargeRepoKnownToggleOnce(tb testing.TB, repoRoot string, packageCount int) time.Duration { + tb.Helper() + + cacheRoot := tb.TempDir() + osTempDir = func() string { return cacheRoot } + + root := tb.TempDir() + writeLargeBenchmarkModule(tb, repoRoot, root, packageCount) + env 
:= append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + mutatedIndex := packageCount / 2 + + if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + tb.Fatalf("baseline Generate returned errors: %v", errs) + } + + mutateLargeBenchmarkModule(tb, root, mutatedIndex) + gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + tb.Fatalf("mutated Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + tb.Fatalf("unexpected mutated Generate results: %+v", gens) + } + + writeLargeBenchmarkPackage(tb, root, mutatedIndex, false) + + start := time.Now() + gens, errs = Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + dur := time.Since(start) + if len(errs) > 0 { + tb.Fatalf("toggle Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + tb.Fatalf("unexpected toggle Generate results: %+v", gens) + } + return dur +} + +func formatPercentImprovement(normal time.Duration, incremental time.Duration) string { + if normal <= 0 { + return "0.0%" + } + improvement := 100 * (float64(normal-incremental) / float64(normal)) + return fmt.Sprintf("%.1f%%", improvement) +} + +func speedupRatio(normal time.Duration, incremental time.Duration) float64 { + if incremental <= 0 { + return 0 + } + return float64(normal) / float64(incremental) +} + +func formatBenchmarkDuration(d time.Duration) string { + switch { + case d >= time.Second: + return fmt.Sprintf("%.2fs", d.Seconds()) + case d >= time.Millisecond: + return fmt.Sprintf("%.2fms", float64(d)/float64(time.Millisecond)) + case d >= time.Microsecond: + return fmt.Sprintf("%.2fµs", float64(d)/float64(time.Microsecond)) + default: + return d.String() + } +} + +func writeLargeBenchmarkModule(tb testing.TB, repoRoot string, root string, packageCount int) { + tb.Helper() + + writeBenchmarkFile(tb, filepath.Join(root, "go.mod"), 
strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + wireImports := []string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"github.com/goforj/wire\"", + } + appImports := []string{ + "package app", + "", + "import (", + } + buildArgs := []string{"\twire.Build("} + argNames := make([]string, 0, packageCount) + for i := 0; i < packageCount; i++ { + pkgName := fmt.Sprintf("layer%02d", i) + wireImports = append(wireImports, fmt.Sprintf("\t%s %q", pkgName, "example.com/app/"+pkgName)) + appImports = append(appImports, fmt.Sprintf("\t%s %q", pkgName, "example.com/app/"+pkgName)) + buildArgs = append(buildArgs, fmt.Sprintf("\t\t%s.NewSet,", pkgName)) + argNames = append(argNames, fmt.Sprintf("dep%02d *%s.Token", i, pkgName)) + } + wireImports = append(wireImports, ")", "") + appImports = append(appImports, ")", "") + wireFile := append([]string{}, wireImports...) + wireFile = append(wireFile, "func Init() *App {") + wireFile = append(wireFile, buildArgs...) 
+ wireFile = append(wireFile, "\t\tNewApp,", "\t)", "\treturn nil", "}", "") + writeBenchmarkFile(tb, filepath.Join(root, "app", "wire.go"), strings.Join(wireFile, "\n")) + + appGo := append(appImports[:len(appImports)-2], // reuse imports without trailing blank line + ")", + "", + "type App struct {", + "\tCount int", + "}", + "", + fmt.Sprintf("func NewApp(%s) *App {", strings.Join(argNames, ", ")), + fmt.Sprintf("\treturn &App{Count: %d}", packageCount), + "}", + "", + ) + writeBenchmarkFile(tb, filepath.Join(root, "app", "app.go"), strings.Join(appGo, "\n")) + + for i := 0; i < packageCount; i++ { + writeLargeBenchmarkPackage(tb, root, i, false) + } +} + +func mutateLargeBenchmarkModule(tb testing.TB, root string, mutatedIndex int) { + tb.Helper() + writeLargeBenchmarkPackage(tb, root, mutatedIndex, true) +} + +func writeLargeBenchmarkPackage(tb testing.TB, root string, index int, mutated bool) { + tb.Helper() + + pkgName := fmt.Sprintf("layer%02d", index) + pkgDir := filepath.Join(root, pkgName) + + writeBenchmarkFile(tb, filepath.Join(pkgDir, "helpers.go"), renderLargeBenchmarkHelpers(pkgName, index, mutated)) + writeBenchmarkFile(tb, filepath.Join(pkgDir, "wire.go"), renderLargeBenchmarkWire(pkgName, mutated)) +} + +func renderLargeBenchmarkHelpers(pkgName string, index int, mutated bool) string { + lines := []string{ + "package " + pkgName, + "", + "import (", + "\t\"fmt\"", + "\t\"strconv\"", + "\t\"strings\"", + ")", + "", + "type Config struct {", + "\tLabel string", + "}", + "", + "type Weight int", + "", + "type Token struct {", + "\tConfig Config", + "\tWeight Weight", + "}", + "", + fmt.Sprintf("func NewConfig() Config { return Config{Label: %q} }", pkgName), + "", + } + if mutated { + lines = append(lines, + fmt.Sprintf("func NewWeight() Weight { return Weight(%d) }", index+100), + "", + "func New(cfg Config, weight Weight) *Token {", + fmt.Sprintf("\t_ = helper%02d()", largeBenchmarkHelperCount-1), + "\treturn &Token{Config: cfg, Weight: weight}", 
+ "}", + "", + ) + } else { + lines = append(lines, + "func New(cfg Config) *Token {", + fmt.Sprintf("\t_ = helper%02d()", largeBenchmarkHelperCount-1), + "\treturn &Token{Config: cfg}", + "}", + "", + ) + } + for i := 0; i < largeBenchmarkHelperCount; i++ { + lines = append(lines, fmt.Sprintf("func helper%02d() string {", i)) + lines = append(lines, fmt.Sprintf("\treturn strings.ToUpper(fmt.Sprintf(\"%%s-%%d\", %q, %d)) + strconv.Itoa(%d)", pkgName, i, index+i)) + lines = append(lines, "}", "") + } + return strings.Join(lines, "\n") +} + +func renderLargeBenchmarkWire(pkgName string, mutated bool) string { + lines := []string{ + "package " + pkgName, + "", + "import (", + "\t\"github.com/goforj/wire\"", + ")", + "", + } + if mutated { + lines = append(lines, "var NewSet = wire.NewSet(NewConfig, NewWeight, New)", "") + } else { + lines = append(lines, "var NewSet = wire.NewSet(NewConfig, New)", "") + } + return strings.Join(lines, "\n") +} + +func strconvQuote(s string) string { + return fmt.Sprintf("%q", s) +} + +func benchmarkRepoRoot(tb testing.TB) string { + tb.Helper() + wd, err := os.Getwd() + if err != nil { + tb.Fatalf("Getwd failed: %v", err) + } + repoRoot := filepath.Clean(filepath.Join(wd, "..", "..")) + if _, err := os.Stat(filepath.Join(repoRoot, "go.mod")); err != nil { + tb.Fatalf("repo root not found at %s: %v", repoRoot, err) + } + return repoRoot +} + +func writeBenchmarkFile(tb testing.TB, path string, content string) { + tb.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + tb.Fatalf("MkdirAll failed: %v", err) + } + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + tb.Fatalf("WriteFile failed: %v", err) + } +} diff --git a/internal/wire/incremental_fingerprint.go b/internal/wire/incremental_fingerprint.go new file mode 100644 index 0000000..886d07f --- /dev/null +++ b/internal/wire/incremental_fingerprint.go @@ -0,0 +1,421 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache 
License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/binary" + "fmt" + "go/ast" + "go/parser" + "go/printer" + "go/token" + "path/filepath" + "sort" + "strings" + + "golang.org/x/tools/go/packages" +) + +const incrementalFingerprintVersion = "wire-incremental-v1" + +type packageFingerprint struct { + Version string + WD string + Tags string + PkgPath string + Files []cacheFile + ShapeHash string + LocalImports []string +} + +type fingerprintStats struct { + localPackages int + metaHits int + metaMisses int + unchanged int + changed int +} + +type incrementalFingerprintSnapshot struct { + stats fingerprintStats + changed []string + fingerprints map[string]*packageFingerprint +} + +func analyzeIncrementalFingerprints(ctx context.Context, wd string, env []string, tags string, pkgs []*packages.Package) *incrementalFingerprintSnapshot { + if !IncrementalEnabled(ctx, env) { + return nil + } + start := timeNow() + snapshot := collectIncrementalFingerprints(wd, tags, pkgs) + debugf(ctx, "incremental.fingerprint local_pkgs=%d meta_hits=%d meta_misses=%d unchanged=%d changed=%d total=%s", + snapshot.stats.localPackages, + snapshot.stats.metaHits, + snapshot.stats.metaMisses, + snapshot.stats.unchanged, + snapshot.stats.changed, + timeSince(start), + ) + if len(snapshot.changed) > 0 { + debugf(ctx, "incremental.fingerprint changed_pkgs=%s", strings.Join(snapshot.changed, ", ")) + } + return snapshot +} + +func 
collectIncrementalFingerprints(wd string, tags string, pkgs []*packages.Package) *incrementalFingerprintSnapshot { + all := collectAllPackages(pkgs) + moduleRoot := findModuleRoot(wd) + snapshot := &incrementalFingerprintSnapshot{ + fingerprints: make(map[string]*packageFingerprint), + } + for _, pkg := range all { + if classifyPackageLocation(moduleRoot, pkg) != "local" { + continue + } + snapshot.stats.localPackages++ + files := packageFingerprintFiles(pkg) + if len(files) == 0 { + continue + } + sort.Strings(files) + metaFiles, err := buildCacheFiles(files) + if err != nil { + snapshot.stats.metaMisses++ + continue + } + key := incrementalFingerprintKey(wd, tags, pkg.PkgPath) + if prev, ok := readIncrementalFingerprint(key); ok && incrementalFingerprintMetaMatches(prev, wd, tags, pkg.PkgPath, metaFiles) { + snapshot.stats.metaHits++ + snapshot.stats.unchanged++ + snapshot.fingerprints[pkg.PkgPath] = prev + continue + } + snapshot.stats.metaMisses++ + fp, err := buildPackageFingerprint(wd, tags, pkg, metaFiles) + if err != nil { + continue + } + prev, hadPrev := readIncrementalFingerprint(key) + writeIncrementalFingerprint(key, fp) + snapshot.fingerprints[pkg.PkgPath] = fp + if hadPrev && incrementalFingerprintEquivalent(prev, fp) { + snapshot.stats.unchanged++ + continue + } + snapshot.stats.changed++ + snapshot.changed = append(snapshot.changed, pkg.PkgPath) + } + sort.Strings(snapshot.changed) + return snapshot +} + +func packageFingerprintFiles(pkg *packages.Package) []string { + if pkg == nil { + return nil + } + if len(pkg.CompiledGoFiles) > 0 { + return append([]string(nil), pkg.CompiledGoFiles...) + } + return append([]string(nil), pkg.GoFiles...) 
+} + +func incrementalFingerprintEquivalent(a, b *packageFingerprint) bool { + if a == nil || b == nil { + return false + } + if a.ShapeHash != b.ShapeHash || a.PkgPath != b.PkgPath || a.Tags != b.Tags || filepath.Clean(a.WD) != filepath.Clean(b.WD) { + return false + } + if len(a.LocalImports) != len(b.LocalImports) { + return false + } + for i := range a.LocalImports { + if a.LocalImports[i] != b.LocalImports[i] { + return false + } + } + return true +} + +func incrementalFingerprintMetaMatches(prev *packageFingerprint, wd string, tags string, pkgPath string, files []cacheFile) bool { + if prev == nil || prev.Version != incrementalFingerprintVersion { + return false + } + if filepath.Clean(prev.WD) != filepath.Clean(wd) || prev.Tags != tags || prev.PkgPath != pkgPath { + return false + } + if len(prev.Files) != len(files) { + return false + } + for i := range prev.Files { + if prev.Files[i] != files[i] { + return false + } + } + return true +} + +func buildPackageFingerprint(wd string, tags string, pkg *packages.Package, files []cacheFile) (*packageFingerprint, error) { + shapeHash, err := packageShapeHash(packageFingerprintFiles(pkg)) + if err != nil { + return nil, err + } + localImports := make([]string, 0, len(pkg.Imports)) + moduleRoot := findModuleRoot(wd) + for _, imp := range pkg.Imports { + if classifyPackageLocation(moduleRoot, imp) == "local" { + localImports = append(localImports, imp.PkgPath) + } + } + sort.Strings(localImports) + return &packageFingerprint{ + Version: incrementalFingerprintVersion, + WD: filepath.Clean(wd), + Tags: tags, + PkgPath: pkg.PkgPath, + Files: append([]cacheFile(nil), files...), + ShapeHash: shapeHash, + LocalImports: localImports, + }, nil +} + +func packageShapeHash(files []string) (string, error) { + fset := token.NewFileSet() + var buf bytes.Buffer + for _, name := range files { + file, err := parser.ParseFile(fset, name, nil, parser.SkipObjectResolution) + if err != nil { + return "", err + } + 
stripFunctionBodies(file) + if err := printer.Fprint(&buf, fset, file); err != nil { + return "", err + } + buf.WriteByte(0) + } + sum := sha256.Sum256(buf.Bytes()) + return fmt.Sprintf("%x", sum[:]), nil +} + +func stripFunctionBodies(file *ast.File) { + if file == nil { + return + } + for _, decl := range file.Decls { + if fn, ok := decl.(*ast.FuncDecl); ok { + fn.Body = nil + fn.Doc = nil + } + } +} + +func incrementalFingerprintKey(wd string, tags string, pkgPath string) string { + h := sha256.New() + h.Write([]byte(incrementalFingerprintVersion)) + h.Write([]byte{0}) + h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte{0}) + h.Write([]byte(tags)) + h.Write([]byte{0}) + h.Write([]byte(pkgPath)) + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func incrementalFingerprintPath(key string) string { + return filepath.Join(cacheDir(), key+".ifp") +} + +func readIncrementalFingerprint(key string) (*packageFingerprint, bool) { + data, err := osReadFile(incrementalFingerprintPath(key)) + if err != nil { + return nil, false + } + fp, err := decodeIncrementalFingerprint(data) + if err != nil { + return nil, false + } + return fp, true +} + +func writeIncrementalFingerprint(key string, fp *packageFingerprint) { + data, err := encodeIncrementalFingerprint(fp) + if err != nil { + return + } + dir := cacheDir() + if err := osMkdirAll(dir, 0755); err != nil { + return + } + tmp, err := osCreateTemp(dir, key+".ifp-") + if err != nil { + return + } + _, writeErr := tmp.Write(data) + closeErr := tmp.Close() + if writeErr != nil || closeErr != nil { + osRemove(tmp.Name()) + return + } + if err := osRename(tmp.Name(), incrementalFingerprintPath(key)); err != nil { + osRemove(tmp.Name()) + } +} + +func encodeIncrementalFingerprint(fp *packageFingerprint) ([]byte, error) { + var buf bytes.Buffer + writeString := func(s string) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(s))); err != nil { + return err + } + _, err := buf.WriteString(s) + return err + } + 
writeCacheFiles := func(files []cacheFile) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(files))); err != nil { + return err + } + for _, f := range files { + if err := writeString(f.Path); err != nil { + return err + } + if err := binary.Write(&buf, binary.LittleEndian, f.Size); err != nil { + return err + } + if err := binary.Write(&buf, binary.LittleEndian, f.ModTime); err != nil { + return err + } + } + return nil + } + writeStrings := func(items []string) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(items))); err != nil { + return err + } + for _, item := range items { + if err := writeString(item); err != nil { + return err + } + } + return nil + } + if fp == nil { + return nil, fmt.Errorf("nil fingerprint") + } + for _, s := range []string{fp.Version, fp.WD, fp.Tags, fp.PkgPath, fp.ShapeHash} { + if err := writeString(s); err != nil { + return nil, err + } + } + if err := writeCacheFiles(fp.Files); err != nil { + return nil, err + } + if err := writeStrings(fp.LocalImports); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func decodeIncrementalFingerprint(data []byte) (*packageFingerprint, error) { + r := bytes.NewReader(data) + readString := func() (string, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return "", err + } + buf := make([]byte, n) + if _, err := r.Read(buf); err != nil { + return "", err + } + return string(buf), nil + } + readCacheFiles := func() ([]cacheFile, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return nil, err + } + out := make([]cacheFile, 0, n) + for i := uint32(0); i < n; i++ { + path, err := readString() + if err != nil { + return nil, err + } + var size int64 + if err := binary.Read(r, binary.LittleEndian, &size); err != nil { + return nil, err + } + var modTime int64 + if err := binary.Read(r, binary.LittleEndian, &modTime); err != nil { + return nil, err + } + out 
= append(out, cacheFile{Path: path, Size: size, ModTime: modTime}) + } + return out, nil + } + readStrings := func() ([]string, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return nil, err + } + out := make([]string, 0, n) + for i := uint32(0); i < n; i++ { + item, err := readString() + if err != nil { + return nil, err + } + out = append(out, item) + } + return out, nil + } + version, err := readString() + if err != nil { + return nil, err + } + wd, err := readString() + if err != nil { + return nil, err + } + tags, err := readString() + if err != nil { + return nil, err + } + pkgPath, err := readString() + if err != nil { + return nil, err + } + shapeHash, err := readString() + if err != nil { + return nil, err + } + files, err := readCacheFiles() + if err != nil { + return nil, err + } + localImports, err := readStrings() + if err != nil { + return nil, err + } + return &packageFingerprint{ + Version: version, + WD: wd, + Tags: tags, + PkgPath: pkgPath, + ShapeHash: shapeHash, + Files: files, + LocalImports: localImports, + }, nil +} diff --git a/internal/wire/incremental_fingerprint_test.go b/internal/wire/incremental_fingerprint_test.go new file mode 100644 index 0000000..afe81de --- /dev/null +++ b/internal/wire/incremental_fingerprint_test.go @@ -0,0 +1,104 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package wire + +import ( + "os" + "path/filepath" + "testing" + + "golang.org/x/tools/go/packages" +) + +func TestPackageShapeHashIgnoresFunctionBodies(t *testing.T) { + dir := t.TempDir() + file := writeTempFile(t, dir, "pkg.go", "package p\n\nfunc Hello() string { return \"a\" }\n") + hash1, err := packageShapeHash([]string{file}) + if err != nil { + t.Fatalf("packageShapeHash first failed: %v", err) + } + if err := os.WriteFile(file, []byte("package p\n\nfunc Hello() string { return \"b\" }\n"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + hash2, err := packageShapeHash([]string{file}) + if err != nil { + t.Fatalf("packageShapeHash second failed: %v", err) + } + if hash1 != hash2 { + t.Fatalf("body-only change should not affect shape hash: %q vs %q", hash1, hash2) + } +} + +func TestIncrementalFingerprintRoundTrip(t *testing.T) { + fp := &packageFingerprint{ + Version: incrementalFingerprintVersion, + WD: "/tmp/app", + Tags: "dev", + PkgPath: "example.com/app", + ShapeHash: "shape", + Files: []cacheFile{{Path: "/tmp/app/pkg.go", Size: 12, ModTime: 34}}, + LocalImports: []string{"example.com/dep"}, + } + data, err := encodeIncrementalFingerprint(fp) + if err != nil { + t.Fatalf("encodeIncrementalFingerprint failed: %v", err) + } + got, err := decodeIncrementalFingerprint(data) + if err != nil { + t.Fatalf("decodeIncrementalFingerprint failed: %v", err) + } + if !incrementalFingerprintEquivalent(fp, got) { + t.Fatalf("fingerprint mismatch after round-trip: got %+v want %+v", got, fp) + } + if len(got.Files) != 1 || got.Files[0] != fp.Files[0] { + t.Fatalf("file metadata mismatch after round-trip: got %+v want %+v", got.Files, fp.Files) + } +} + +func TestCollectIncrementalFingerprintsTreatsBodyOnlyChangeAsUnchanged(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + root := t.TempDir() + writeFile(t, 
filepath.Join(root, "go.mod"), "module example.com/test\n\ngo 1.21\n") + file := filepath.Join(root, "app", "app.go") + writeFile(t, file, "package app\n\nfunc Hello() string { return \"a\" }\n") + pkg := &packages.Package{ + PkgPath: "example.com/app", + CompiledGoFiles: []string{file}, + GoFiles: []string{file}, + Imports: map[string]*packages.Package{}, + } + + snapshot := collectIncrementalFingerprints(root, "", []*packages.Package{pkg}) + if snapshot.stats.changed != 1 || len(snapshot.changed) != 1 || snapshot.changed[0] != pkg.PkgPath { + t.Fatalf("first run stats=%+v changed=%v", snapshot.stats, snapshot.changed) + } + + if err := os.WriteFile(file, []byte("package app\n\nfunc Hello() string { return \"b\" }\n"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + snapshot = collectIncrementalFingerprints(root, "", []*packages.Package{pkg}) + if snapshot.stats.unchanged != 1 { + t.Fatalf("body-only change should be unchanged by shape, stats=%+v changed=%v", snapshot.stats, snapshot.changed) + } + if len(snapshot.changed) != 0 { + t.Fatalf("body-only change should not report changed packages, got %v", snapshot.changed) + } +} diff --git a/internal/wire/incremental_graph.go b/internal/wire/incremental_graph.go new file mode 100644 index 0000000..66cf28d --- /dev/null +++ b/internal/wire/incremental_graph.go @@ -0,0 +1,306 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package wire + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/binary" + "fmt" + "path/filepath" + "sort" + "strings" + + "golang.org/x/tools/go/packages" +) + +const incrementalGraphVersion = "wire-incremental-graph-v1" + +type incrementalGraph struct { + Version string + WD string + Tags string + Roots []string + LocalReverse map[string][]string +} + +func analyzeIncrementalGraph(ctx context.Context, wd string, env []string, tags string, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot) { + if !IncrementalEnabled(ctx, env) || snapshot == nil { + return + } + graph := buildIncrementalGraph(wd, tags, pkgs) + writeIncrementalGraph(incrementalGraphKey(wd, tags, graph.Roots), graph) + if len(snapshot.changed) == 0 { + return + } + affected := affectedRoots(graph, snapshot.changed) + if len(affected) > 0 { + debugf(ctx, "incremental.graph changed=%s affected_roots=%s", stringsJoin(snapshot.changed), stringsJoin(affected)) + } else { + debugf(ctx, "incremental.graph changed=%s affected_roots=", stringsJoin(snapshot.changed)) + } +} + +func buildIncrementalGraph(wd string, tags string, pkgs []*packages.Package) *incrementalGraph { + moduleRoot := findModuleRoot(wd) + graph := &incrementalGraph{ + Version: incrementalGraphVersion, + WD: filepath.Clean(wd), + Tags: tags, + Roots: make([]string, 0, len(pkgs)), + LocalReverse: make(map[string][]string), + } + for _, pkg := range pkgs { + if pkg == nil { + continue + } + graph.Roots = append(graph.Roots, pkg.PkgPath) + } + sort.Strings(graph.Roots) + for _, pkg := range collectAllPackages(pkgs) { + if classifyPackageLocation(moduleRoot, pkg) != "local" { + continue + } + for _, imp := range pkg.Imports { + if classifyPackageLocation(moduleRoot, imp) != "local" { + continue + } + graph.LocalReverse[imp.PkgPath] = append(graph.LocalReverse[imp.PkgPath], pkg.PkgPath) + } + } + for path := range graph.LocalReverse { + sort.Strings(graph.LocalReverse[path]) + } + return graph +} + +func 
affectedRoots(graph *incrementalGraph, changed []string) []string { + if graph == nil || len(changed) == 0 { + return nil + } + rootSet := make(map[string]struct{}, len(graph.Roots)) + for _, root := range graph.Roots { + rootSet[root] = struct{}{} + } + seen := make(map[string]struct{}) + queue := append([]string(nil), changed...) + affected := make(map[string]struct{}) + for len(queue) > 0 { + cur := queue[0] + queue = queue[1:] + if _, ok := seen[cur]; ok { + continue + } + seen[cur] = struct{}{} + if _, ok := rootSet[cur]; ok { + affected[cur] = struct{}{} + } + for _, next := range graph.LocalReverse[cur] { + if _, ok := seen[next]; !ok { + queue = append(queue, next) + } + } + } + out := make([]string, 0, len(affected)) + for root := range affected { + out = append(out, root) + } + sort.Strings(out) + return out +} + +func incrementalGraphKey(wd string, tags string, roots []string) string { + h := sha256.New() + h.Write([]byte(incrementalGraphVersion)) + h.Write([]byte{0}) + h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte{0}) + h.Write([]byte(tags)) + h.Write([]byte{0}) + for _, root := range roots { + h.Write([]byte(root)) + h.Write([]byte{0}) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func incrementalGraphPath(key string) string { + return filepath.Join(cacheDir(), key+".igr") +} + +func writeIncrementalGraph(key string, graph *incrementalGraph) { + data, err := encodeIncrementalGraph(graph) + if err != nil { + return + } + dir := cacheDir() + if err := osMkdirAll(dir, 0755); err != nil { + return + } + tmp, err := osCreateTemp(dir, key+".igr-") + if err != nil { + return + } + _, writeErr := tmp.Write(data) + closeErr := tmp.Close() + if writeErr != nil || closeErr != nil { + osRemove(tmp.Name()) + return + } + if err := osRename(tmp.Name(), incrementalGraphPath(key)); err != nil { + osRemove(tmp.Name()) + } +} + +func readIncrementalGraph(key string) (*incrementalGraph, bool) { + data, err := osReadFile(incrementalGraphPath(key)) + if err != 
nil { + return nil, false + } + graph, err := decodeIncrementalGraph(data) + if err != nil { + return nil, false + } + return graph, true +} + +func encodeIncrementalGraph(graph *incrementalGraph) ([]byte, error) { + if graph == nil { + return nil, fmt.Errorf("nil incremental graph") + } + var buf bytes.Buffer + writeString := func(s string) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(s))); err != nil { + return err + } + _, err := buf.WriteString(s) + return err + } + for _, s := range []string{graph.Version, graph.WD, graph.Tags} { + if err := writeString(s); err != nil { + return nil, err + } + } + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(graph.Roots))); err != nil { + return nil, err + } + for _, root := range graph.Roots { + if err := writeString(root); err != nil { + return nil, err + } + } + keys := make([]string, 0, len(graph.LocalReverse)) + for k := range graph.LocalReverse { + keys = append(keys, k) + } + sort.Strings(keys) + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(keys))); err != nil { + return nil, err + } + for _, k := range keys { + if err := writeString(k); err != nil { + return nil, err + } + children := append([]string(nil), graph.LocalReverse[k]...) 
+ sort.Strings(children) + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(children))); err != nil { + return nil, err + } + for _, child := range children { + if err := writeString(child); err != nil { + return nil, err + } + } + } + return buf.Bytes(), nil +} + +func decodeIncrementalGraph(data []byte) (*incrementalGraph, error) { + r := bytes.NewReader(data) + readString := func() (string, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return "", err + } + buf := make([]byte, n) + if _, err := r.Read(buf); err != nil { + return "", err + } + return string(buf), nil + } + version, err := readString() + if err != nil { + return nil, err + } + wd, err := readString() + if err != nil { + return nil, err + } + tags, err := readString() + if err != nil { + return nil, err + } + var rootCount uint32 + if err := binary.Read(r, binary.LittleEndian, &rootCount); err != nil { + return nil, err + } + roots := make([]string, 0, rootCount) + for i := uint32(0); i < rootCount; i++ { + root, err := readString() + if err != nil { + return nil, err + } + roots = append(roots, root) + } + var edgeCount uint32 + if err := binary.Read(r, binary.LittleEndian, &edgeCount); err != nil { + return nil, err + } + reverse := make(map[string][]string, edgeCount) + for i := uint32(0); i < edgeCount; i++ { + k, err := readString() + if err != nil { + return nil, err + } + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return nil, err + } + children := make([]string, 0, n) + for j := uint32(0); j < n; j++ { + child, err := readString() + if err != nil { + return nil, err + } + children = append(children, child) + } + reverse[k] = children + } + return &incrementalGraph{ + Version: version, + WD: wd, + Tags: tags, + Roots: roots, + LocalReverse: reverse, + }, nil +} + +func stringsJoin(items []string) string { + if len(items) == 0 { + return "" + } + return strings.Join(items, ",") +} diff --git 
a/internal/wire/incremental_graph_test.go b/internal/wire/incremental_graph_test.go new file mode 100644 index 0000000..8a91b54 --- /dev/null +++ b/internal/wire/incremental_graph_test.go @@ -0,0 +1,97 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "path/filepath" + "reflect" + "testing" + + "golang.org/x/tools/go/packages" +) + +func TestIncrementalGraphRoundTrip(t *testing.T) { + graph := &incrementalGraph{ + Version: incrementalGraphVersion, + WD: "/tmp/app", + Tags: "dev", + Roots: []string{"example.com/app", "example.com/other"}, + LocalReverse: map[string][]string{ + "example.com/dep": {"example.com/app"}, + "example.com/sub": {"example.com/dep", "example.com/other"}, + }, + } + data, err := encodeIncrementalGraph(graph) + if err != nil { + t.Fatalf("encodeIncrementalGraph failed: %v", err) + } + got, err := decodeIncrementalGraph(data) + if err != nil { + t.Fatalf("decodeIncrementalGraph failed: %v", err) + } + if !reflect.DeepEqual(got, graph) { + t.Fatalf("graph round-trip mismatch:\n got=%+v\nwant=%+v", got, graph) + } +} + +func TestAffectedRoots(t *testing.T) { + graph := &incrementalGraph{ + Roots: []string{"example.com/app", "example.com/other"}, + LocalReverse: map[string][]string{ + "example.com/dep": {"example.com/app"}, + "example.com/sub": {"example.com/dep", "example.com/other"}, + }, + } + got := affectedRoots(graph, []string{"example.com/sub"}) + want := 
[]string{"example.com/app", "example.com/other"} + if !reflect.DeepEqual(got, want) { + t.Fatalf("affectedRoots=%v want %v", got, want) + } +} + +func TestBuildIncrementalGraph(t *testing.T) { + root := t.TempDir() + writeFile(t, filepath.Join(root, "go.mod"), "module example.com/test\n\ngo 1.21\n") + + appFile := filepath.Join(root, "app", "app.go") + depFile := filepath.Join(root, "dep", "dep.go") + writeFile(t, appFile, "package app\n") + writeFile(t, depFile, "package dep\n") + + dep := &packages.Package{ + PkgPath: "example.com/test/dep", + CompiledGoFiles: []string{depFile}, + GoFiles: []string{depFile}, + Imports: map[string]*packages.Package{}, + } + app := &packages.Package{ + PkgPath: "example.com/test/app", + CompiledGoFiles: []string{appFile}, + GoFiles: []string{appFile}, + Imports: map[string]*packages.Package{ + "example.com/test/dep": dep, + }, + } + + graph := buildIncrementalGraph(root, "", []*packages.Package{app}) + if len(graph.Roots) != 1 || graph.Roots[0] != app.PkgPath { + t.Fatalf("unexpected roots: %v", graph.Roots) + } + got := graph.LocalReverse[dep.PkgPath] + want := []string{app.PkgPath} + if !reflect.DeepEqual(got, want) { + t.Fatalf("unexpected reverse edges: got=%v want=%v", got, want) + } +} diff --git a/internal/wire/incremental_manifest.go b/internal/wire/incremental_manifest.go new file mode 100644 index 0000000..ae36c77 --- /dev/null +++ b/internal/wire/incremental_manifest.go @@ -0,0 +1,876 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "bytes" + "context" + "crypto/sha256" + "encoding/binary" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + + "golang.org/x/tools/go/packages" +) + +const incrementalManifestVersion = "wire-incremental-manifest-v1" + +type incrementalManifest struct { + Version string + WD string + Tags string + Prefix string + HeaderHash string + EnvHash string + Patterns []string + LocalPackages []packageFingerprint + ExternalPkgs []externalPackageExport + ExternalFiles []cacheFile + ExtraFiles []cacheFile + Outputs []incrementalOutput +} + +type externalPackageExport struct { + PkgPath string + ExportFile string +} + +type incrementalOutput struct { + PkgPath string + OutputPath string + ContentKey string +} + +type incrementalPreloadState struct { + selectorKey string + manifest *incrementalManifest + valid bool + currentLocal []packageFingerprint + reason string +} + +func readPreloadIncrementalManifestResults(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) ([]GenerateResult, bool) { + state, ok := prepareIncrementalPreloadState(ctx, wd, env, patterns, opts) + return readPreloadIncrementalManifestResultsFromState(ctx, wd, env, patterns, opts, state, ok) +} + +func readPreloadIncrementalManifestResultsFromState(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions, state *incrementalPreloadState, ok bool) ([]GenerateResult, bool) { + if !ok { + debugf(ctx, "incremental.preload_manifest miss reason=no_manifest") + return nil, false + } + if state.valid { + results, ok := incrementalManifestOutputs(state.manifest) + if !ok { + debugf(ctx, "incremental.preload_manifest miss reason=outputs") + return nil, false + } + debugf(ctx, "incremental.preload_manifest hit outputs=%d", len(results)) + return results, true + } else if archived := 
readStateIncrementalManifest(state.selectorKey, state.currentLocal); archived != nil { + if ok, _, _ := incrementalManifestPreloadValid(ctx, archived, wd, env, patterns, opts); ok { + results, ok := incrementalManifestOutputs(archived) + if !ok { + debugf(ctx, "incremental.preload_manifest miss reason=state_outputs") + return nil, false + } + writeIncrementalManifestFile(state.selectorKey, archived) + debugf(ctx, "incremental.preload_manifest state_hit outputs=%d", len(results)) + return results, true + } + debugf(ctx, "incremental.preload_manifest miss reason=%s", state.reason) + return nil, false + } else { + debugf(ctx, "incremental.preload_manifest miss reason=%s", state.reason) + return nil, false + } +} + +func prepareIncrementalPreloadState(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) (*incrementalPreloadState, bool) { + selectorKey := incrementalManifestSelectorKey(wd, env, patterns, opts) + manifest, ok := readIncrementalManifest(selectorKey) + if !ok { + return nil, false + } + valid, currentLocal, reason := incrementalManifestPreloadValid(ctx, manifest, wd, env, patterns, opts) + return &incrementalPreloadState{ + selectorKey: selectorKey, + manifest: manifest, + valid: valid, + currentLocal: currentLocal, + reason: reason, + }, true +} + +func readIncrementalManifestResults(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot) ([]GenerateResult, bool) { + if snapshot == nil || snapshot.stats.changed != 0 { + return nil, false + } + key := incrementalManifestSelectorKey(wd, env, patterns, opts) + manifest, ok := readIncrementalManifest(key) + if !ok || !incrementalManifestValid(manifest, wd, env, patterns, opts, pkgs) { + return nil, false + } + results := make([]GenerateResult, 0, len(manifest.Outputs)) + for _, out := range manifest.Outputs { + content, ok := readCache(out.ContentKey) + if !ok { + return nil, 
false + } + results = append(results, GenerateResult{ + PkgPath: out.PkgPath, + OutputPath: out.OutputPath, + Content: content, + }) + } + debugf(ctx, "incremental.manifest hit outputs=%d", len(results)) + return results, true +} + +func writeIncrementalManifest(wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot, generated []GenerateResult) { + if snapshot == nil || len(generated) == 0 { + return + } + externalPkgs := buildExternalPackageExports(wd, pkgs) + externalFiles, err := buildExternalPackageFiles(wd, pkgs) + if err != nil { + return + } + manifest := &incrementalManifest{ + Version: incrementalManifestVersion, + WD: filepath.Clean(wd), + Tags: opts.Tags, + Prefix: opts.PrefixOutputFile, + HeaderHash: headerHash(opts.Header), + EnvHash: envHash(env), + Patterns: sortedStrings(patterns), + LocalPackages: snapshotPackageFingerprints(snapshot), + ExternalPkgs: externalPkgs, + ExternalFiles: externalFiles, + ExtraFiles: extraCacheFiles(wd), + } + for _, out := range generated { + if len(out.Content) == 0 || out.OutputPath == "" { + continue + } + contentKey := incrementalContentKey(out.Content) + writeCache(contentKey, out.Content) + manifest.Outputs = append(manifest.Outputs, incrementalOutput{ + PkgPath: out.PkgPath, + OutputPath: out.OutputPath, + ContentKey: contentKey, + }) + } + if len(manifest.Outputs) == 0 { + return + } + selectorKey := incrementalManifestSelectorKey(wd, env, patterns, opts) + stateKey := incrementalManifestStateKey(selectorKey, manifest.LocalPackages) + writeIncrementalManifestFile(selectorKey, manifest) + writeIncrementalManifestFile(stateKey, manifest) +} + +func incrementalManifestSelectorKey(wd string, env []string, patterns []string, opts *GenerateOptions) string { + h := sha256.New() + h.Write([]byte(incrementalManifestVersion)) + h.Write([]byte{0}) + h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte{0}) + h.Write([]byte(envHash(env))) + 
h.Write([]byte{0}) + h.Write([]byte(opts.Tags)) + h.Write([]byte{0}) + h.Write([]byte(opts.PrefixOutputFile)) + h.Write([]byte{0}) + h.Write([]byte(headerHash(opts.Header))) + h.Write([]byte{0}) + for _, p := range sortedStrings(patterns) { + h.Write([]byte(p)) + h.Write([]byte{0}) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func snapshotPackageFingerprints(snapshot *incrementalFingerprintSnapshot) []packageFingerprint { + if snapshot == nil || len(snapshot.fingerprints) == 0 { + return nil + } + paths := make([]string, 0, len(snapshot.fingerprints)) + for path := range snapshot.fingerprints { + paths = append(paths, path) + } + sort.Strings(paths) + out := make([]packageFingerprint, 0, len(paths)) + for _, path := range paths { + if fp := snapshot.fingerprints[path]; fp != nil { + out = append(out, *fp) + } + } + return out +} + +func buildExternalPackageFiles(wd string, pkgs []*packages.Package) ([]cacheFile, error) { + moduleRoot := findModuleRoot(wd) + seen := make(map[string]struct{}) + var files []string + for _, pkg := range collectAllPackages(pkgs) { + if classifyPackageLocation(moduleRoot, pkg) == "local" { + continue + } + names := pkg.CompiledGoFiles + if len(names) == 0 { + names = pkg.GoFiles + } + for _, name := range names { + clean := filepath.Clean(name) + if _, ok := seen[clean]; ok { + continue + } + seen[clean] = struct{}{} + files = append(files, clean) + } + } + sort.Strings(files) + return buildCacheFiles(files) +} + +func buildExternalPackageExports(wd string, pkgs []*packages.Package) []externalPackageExport { + moduleRoot := findModuleRoot(wd) + out := make([]externalPackageExport, 0) + for _, pkg := range collectAllPackages(pkgs) { + if classifyPackageLocation(moduleRoot, pkg) == "local" { + continue + } + if pkg == nil || pkg.PkgPath == "" || pkg.ExportFile == "" { + continue + } + out = append(out, externalPackageExport{ + PkgPath: pkg.PkgPath, + ExportFile: pkg.ExportFile, + }) + } + sort.Slice(out, func(i, j int) bool { return 
out[i].PkgPath < out[j].PkgPath }) + return out +} + +func incrementalManifestValid(manifest *incrementalManifest, wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package) bool { + if manifest == nil || manifest.Version != incrementalManifestVersion { + return false + } + if filepath.Clean(manifest.WD) != filepath.Clean(wd) || manifest.Tags != opts.Tags || manifest.Prefix != opts.PrefixOutputFile { + return false + } + if manifest.HeaderHash != headerHash(opts.Header) || manifest.EnvHash != envHash(env) { + return false + } + if len(manifest.Patterns) != len(patterns) { + return false + } + for i, p := range sortedStrings(patterns) { + if manifest.Patterns[i] != p { + return false + } + } + currentExternal, err := buildExternalPackageFiles(wd, pkgs) + if err != nil || len(currentExternal) != len(manifest.ExternalFiles) { + return false + } + for i := range currentExternal { + if currentExternal[i] != manifest.ExternalFiles[i] { + return false + } + } + if len(manifest.ExtraFiles) > 0 { + current, err := buildCacheFilesFromMeta(manifest.ExtraFiles) + if err != nil || len(current) != len(manifest.ExtraFiles) { + return false + } + for i := range current { + if current[i] != manifest.ExtraFiles[i] { + return false + } + } + } + return len(manifest.Outputs) > 0 +} + +func incrementalManifestPreloadValid(ctx context.Context, manifest *incrementalManifest, wd string, env []string, patterns []string, opts *GenerateOptions) (bool, []packageFingerprint, string) { + if manifest == nil || manifest.Version != incrementalManifestVersion { + return false, nil, "version" + } + if filepath.Clean(manifest.WD) != filepath.Clean(wd) || manifest.Tags != opts.Tags || manifest.Prefix != opts.PrefixOutputFile { + return false, nil, "config" + } + if manifest.HeaderHash != headerHash(opts.Header) || manifest.EnvHash != envHash(env) { + return false, nil, "env" + } + if len(manifest.Patterns) != len(patterns) { + return false, nil, "patterns.length" + } 
+ for i, p := range sortedStrings(patterns) { + if manifest.Patterns[i] != p { + return false, nil, "patterns.value" + } + } + if len(manifest.ExtraFiles) > 0 { + current, err := buildCacheFilesFromMeta(manifest.ExtraFiles) + if err != nil || len(current) != len(manifest.ExtraFiles) { + return false, nil, "extra_files" + } + for i := range current { + if current[i] != manifest.ExtraFiles[i] { + return false, nil, "extra_files.diff" + } + } + } + currentLocal, ok, reason := incrementalManifestCurrentLocalPackages(ctx, manifest.LocalPackages) + if !ok { + return false, currentLocal, "local_packages." + reason + } + if len(manifest.ExternalFiles) > 0 { + current, err := buildCacheFilesFromMeta(manifest.ExternalFiles) + if err != nil || len(current) != len(manifest.ExternalFiles) { + return false, currentLocal, "external_files" + } + for i := range current { + if current[i] != manifest.ExternalFiles[i] { + return false, currentLocal, "external_files.diff" + } + } + } + if len(manifest.Outputs) == 0 { + return false, currentLocal, "outputs" + } + return true, currentLocal, "" +} + +func incrementalManifestCurrentLocalPackages(ctx context.Context, local []packageFingerprint) ([]packageFingerprint, bool, string) { + currentState := make([]packageFingerprint, 0, len(local)) + var firstReason string + for _, fp := range local { + if len(fp.Files) == 0 { + if firstReason == "" { + firstReason = fp.PkgPath + ".files" + } + continue + } + storedFiles := filesFromMeta(fp.Files) + if len(storedFiles) == 0 { + if firstReason == "" { + firstReason = fp.PkgPath + ".stored_files" + } + continue + } + currentMeta, err := buildCacheFiles(storedFiles) + if err != nil { + debugf(ctx, "incremental.preload_manifest local_pkg=%s meta_error=%v", fp.PkgPath, err) + if firstReason == "" { + firstReason = fp.PkgPath + ".meta_error" + } + continue + } + currentFP := fp + currentFP.Files = append([]cacheFile(nil), currentMeta...) 
+ sameMeta := len(currentMeta) == len(fp.Files) + if sameMeta { + for i := range currentMeta { + if currentMeta[i] != fp.Files[i] { + sameMeta = false + break + } + } + } + if !sameMeta { + shapeHash, err := packageShapeHash(storedFiles) + if err != nil { + debugf(ctx, "incremental.preload_manifest local_pkg=%s shape_error=%v", fp.PkgPath, err) + if firstReason == "" { + firstReason = fp.PkgPath + ".shape_error" + } + continue + } + currentFP.ShapeHash = shapeHash + if shapeHash != fp.ShapeHash { + debugf(ctx, "incremental.preload_manifest local_pkg=%s stored_shape=%s current_shape=%s files=%s", fp.PkgPath, fp.ShapeHash, shapeHash, strings.Join(storedFiles, ",")) + if firstReason == "" { + firstReason = fp.PkgPath + ".shape_mismatch" + } + } + } + if changed, err := packageDirectoryIntroducedRelevantFiles(fp.Files); err != nil { + debugf(ctx, "incremental.preload_manifest local_pkg=%s dir_scan_error=%v", fp.PkgPath, err) + if firstReason == "" { + firstReason = fp.PkgPath + ".dir_scan_error" + } + continue + } else if changed { + debugf(ctx, "incremental.preload_manifest local_pkg=%s introduced_relevant_files=true", fp.PkgPath) + if firstReason == "" { + firstReason = fp.PkgPath + ".introduced_relevant_files" + } + } + currentState = append(currentState, currentFP) + } + if firstReason != "" { + return currentState, false, firstReason + } + return currentState, true, "" +} + +func incrementalManifestOutputs(manifest *incrementalManifest) ([]GenerateResult, bool) { + results := make([]GenerateResult, 0, len(manifest.Outputs)) + for _, out := range manifest.Outputs { + content, ok := readCache(out.ContentKey) + if !ok { + return nil, false + } + results = append(results, GenerateResult{ + PkgPath: out.PkgPath, + OutputPath: out.OutputPath, + Content: content, + }) + } + return results, true +} + +func readStateIncrementalManifest(selectorKey string, local []packageFingerprint) *incrementalManifest { + if len(local) == 0 { + return nil + } + stateKey := 
incrementalManifestStateKey(selectorKey, local) + manifest, ok := readIncrementalManifest(stateKey) + if !ok { + return nil + } + return manifest +} + +func incrementalManifestStateKey(selectorKey string, local []packageFingerprint) string { + h := sha256.New() + h.Write([]byte(selectorKey)) + h.Write([]byte{0}) + for _, fp := range snapshotPackageFingerprints(&incrementalFingerprintSnapshot{fingerprints: fingerprintsFromSlice(local)}) { + h.Write([]byte(fp.PkgPath)) + h.Write([]byte{0}) + h.Write([]byte(fp.ShapeHash)) + h.Write([]byte{0}) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func fingerprintsFromSlice(local []packageFingerprint) map[string]*packageFingerprint { + if len(local) == 0 { + return nil + } + out := make(map[string]*packageFingerprint, len(local)) + for i := range local { + fp := local[i] + out[fp.PkgPath] = &fp + } + return out +} + +func filesFromMeta(files []cacheFile) []string { + out := make([]string, 0, len(files)) + for _, f := range files { + out = append(out, filepath.Clean(f.Path)) + } + sort.Strings(out) + return out +} + +func packageDirectoryIntroducedRelevantFiles(files []cacheFile) (bool, error) { + dirs := make(map[string]struct{}) + old := make(map[string]struct{}, len(files)) + for _, f := range files { + path := filepath.Clean(f.Path) + dirs[filepath.Dir(path)] = struct{}{} + old[path] = struct{}{} + } + for dir := range dirs { + entries, err := os.ReadDir(dir) + if err != nil { + return false, err + } + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + if !strings.HasSuffix(name, ".go") { + continue + } + if strings.HasSuffix(name, "_test.go") { + continue + } + if strings.HasSuffix(name, "wire_gen.go") { + continue + } + path := filepath.Clean(filepath.Join(dir, name)) + if _, ok := old[path]; !ok { + return true, nil + } + } + } + return false, nil +} + +func incrementalManifestPath(key string) string { + return filepath.Join(cacheDir(), key+".iman") +} + +func 
readIncrementalManifest(key string) (*incrementalManifest, bool) { + data, err := osReadFile(incrementalManifestPath(key)) + if err != nil { + return nil, false + } + manifest, err := decodeIncrementalManifest(data) + if err != nil { + return nil, false + } + return manifest, true +} + +func writeIncrementalManifestFile(key string, manifest *incrementalManifest) { + data, err := encodeIncrementalManifest(manifest) + if err != nil { + return + } + dir := cacheDir() + if err := osMkdirAll(dir, 0755); err != nil { + return + } + tmp, err := osCreateTemp(dir, key+".iman-") + if err != nil { + return + } + _, writeErr := tmp.Write(data) + closeErr := tmp.Close() + if writeErr != nil || closeErr != nil { + osRemove(tmp.Name()) + return + } + if err := osRename(tmp.Name(), incrementalManifestPath(key)); err != nil { + osRemove(tmp.Name()) + } +} + +func encodeIncrementalManifest(manifest *incrementalManifest) ([]byte, error) { + var buf bytes.Buffer + if manifest == nil { + return nil, fmt.Errorf("nil incremental manifest") + } + writeString := func(s string) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(s))); err != nil { + return err + } + _, err := buf.WriteString(s) + return err + } + writeCacheFiles := func(files []cacheFile) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(files))); err != nil { + return err + } + for _, f := range files { + if err := writeString(f.Path); err != nil { + return err + } + if err := binary.Write(&buf, binary.LittleEndian, f.Size); err != nil { + return err + } + if err := binary.Write(&buf, binary.LittleEndian, f.ModTime); err != nil { + return err + } + } + return nil + } + writeExternalPkgs := func(pkgs []externalPackageExport) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(pkgs))); err != nil { + return err + } + for _, pkg := range pkgs { + if err := writeString(pkg.PkgPath); err != nil { + return err + } + if err := writeString(pkg.ExportFile); err != nil { + 
return err + } + } + return nil + } + writeFingerprints := func(fps []packageFingerprint) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(fps))); err != nil { + return err + } + for _, fp := range fps { + for _, s := range []string{fp.Version, fp.WD, fp.Tags, fp.PkgPath, fp.ShapeHash} { + if err := writeString(s); err != nil { + return err + } + } + if err := writeCacheFiles(fp.Files); err != nil { + return err + } + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(fp.LocalImports))); err != nil { + return err + } + for _, imp := range fp.LocalImports { + if err := writeString(imp); err != nil { + return err + } + } + } + return nil + } + writeOutputs := func(outputs []incrementalOutput) error { + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(outputs))); err != nil { + return err + } + for _, out := range outputs { + for _, s := range []string{out.PkgPath, out.OutputPath, out.ContentKey} { + if err := writeString(s); err != nil { + return err + } + } + } + return nil + } + for _, s := range []string{manifest.Version, manifest.WD, manifest.Tags, manifest.Prefix, manifest.HeaderHash, manifest.EnvHash} { + if err := writeString(s); err != nil { + return nil, err + } + } + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(manifest.Patterns))); err != nil { + return nil, err + } + for _, p := range manifest.Patterns { + if err := writeString(p); err != nil { + return nil, err + } + } + if err := writeFingerprints(manifest.LocalPackages); err != nil { + return nil, err + } + if err := writeExternalPkgs(manifest.ExternalPkgs); err != nil { + return nil, err + } + if err := writeCacheFiles(manifest.ExternalFiles); err != nil { + return nil, err + } + if err := writeCacheFiles(manifest.ExtraFiles); err != nil { + return nil, err + } + if err := writeOutputs(manifest.Outputs); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func decodeIncrementalManifest(data []byte) (*incrementalManifest, error) 
{ + r := bytes.NewReader(data) + readString := func() (string, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return "", err + } + buf := make([]byte, n) + if _, err := r.Read(buf); err != nil { + return "", err + } + return string(buf), nil + } + readCacheFiles := func() ([]cacheFile, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return nil, err + } + out := make([]cacheFile, 0, n) + for i := uint32(0); i < n; i++ { + path, err := readString() + if err != nil { + return nil, err + } + var size int64 + if err := binary.Read(r, binary.LittleEndian, &size); err != nil { + return nil, err + } + var modTime int64 + if err := binary.Read(r, binary.LittleEndian, &modTime); err != nil { + return nil, err + } + out = append(out, cacheFile{Path: path, Size: size, ModTime: modTime}) + } + return out, nil + } + readExternalPkgs := func() ([]externalPackageExport, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return nil, err + } + out := make([]externalPackageExport, 0, n) + for i := uint32(0); i < n; i++ { + pkgPath, err := readString() + if err != nil { + return nil, err + } + exportFile, err := readString() + if err != nil { + return nil, err + } + out = append(out, externalPackageExport{PkgPath: pkgPath, ExportFile: exportFile}) + } + return out, nil + } + readFingerprints := func() ([]packageFingerprint, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return nil, err + } + out := make([]packageFingerprint, 0, n) + for i := uint32(0); i < n; i++ { + version, err := readString() + if err != nil { + return nil, err + } + wd, err := readString() + if err != nil { + return nil, err + } + tags, err := readString() + if err != nil { + return nil, err + } + pkgPath, err := readString() + if err != nil { + return nil, err + } + shapeHash, err := readString() + if err != nil { + return nil, err + } + 
files, err := readCacheFiles() + if err != nil { + return nil, err + } + var importCount uint32 + if err := binary.Read(r, binary.LittleEndian, &importCount); err != nil { + return nil, err + } + localImports := make([]string, 0, importCount) + for j := uint32(0); j < importCount; j++ { + imp, err := readString() + if err != nil { + return nil, err + } + localImports = append(localImports, imp) + } + out = append(out, packageFingerprint{ + Version: version, + WD: wd, + Tags: tags, + PkgPath: pkgPath, + ShapeHash: shapeHash, + Files: files, + LocalImports: localImports, + }) + } + return out, nil + } + readOutputs := func() ([]incrementalOutput, error) { + var n uint32 + if err := binary.Read(r, binary.LittleEndian, &n); err != nil { + return nil, err + } + out := make([]incrementalOutput, 0, n) + for i := uint32(0); i < n; i++ { + pkgPath, err := readString() + if err != nil { + return nil, err + } + outputPath, err := readString() + if err != nil { + return nil, err + } + contentKey, err := readString() + if err != nil { + return nil, err + } + out = append(out, incrementalOutput{PkgPath: pkgPath, OutputPath: outputPath, ContentKey: contentKey}) + } + return out, nil + } + fields := make([]string, 6) + for i := range fields { + s, err := readString() + if err != nil { + return nil, err + } + fields[i] = s + } + var patternCount uint32 + if err := binary.Read(r, binary.LittleEndian, &patternCount); err != nil { + return nil, err + } + patterns := make([]string, 0, patternCount) + for i := uint32(0); i < patternCount; i++ { + p, err := readString() + if err != nil { + return nil, err + } + patterns = append(patterns, p) + } + localPackages, err := readFingerprints() + if err != nil { + return nil, err + } + externalPkgs, err := readExternalPkgs() + if err != nil { + return nil, err + } + externalFiles, err := readCacheFiles() + if err != nil { + return nil, err + } + extraFiles, err := readCacheFiles() + if err != nil { + return nil, err + } + outputs, err := 
readOutputs() + if err != nil { + return nil, err + } + return &incrementalManifest{ + Version: fields[0], + WD: fields[1], + Tags: fields[2], + Prefix: fields[3], + HeaderHash: fields[4], + EnvHash: fields[5], + Patterns: patterns, + LocalPackages: localPackages, + ExternalPkgs: externalPkgs, + ExternalFiles: externalFiles, + ExtraFiles: extraFiles, + Outputs: outputs, + }, nil +} + +func incrementalContentKey(content []byte) string { + sum := sha256.Sum256(content) + return fmt.Sprintf("%x", sum[:]) +} diff --git a/internal/wire/incremental_session.go b/internal/wire/incremental_session.go new file mode 100644 index 0000000..fda6605 --- /dev/null +++ b/internal/wire/incremental_session.go @@ -0,0 +1,95 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
// incrementalSession carries parse state shared across runs for one
// (wd, env, tags) combination: a FileSet plus a cache of parsed dependency
// files keyed by cleaned filename.
type incrementalSession struct {
	fset       *token.FileSet
	mu         sync.Mutex
	parsedDeps map[string]cachedParsedFile
}

// cachedParsedFile pairs a parsed file with the SHA-256 of the source it was
// parsed from, so entries built from different bytes are treated as misses.
type cachedParsedFile struct {
	hash string
	file *ast.File
}

// incrementalSessions maps sessionKey -> *incrementalSession.
var incrementalSessions sync.Map

// sessionKey builds the lookup key: cleaned wd and tags separated by
// newlines, followed by each environment entry terminated by a NUL byte.
func sessionKey(wd string, env []string, tags string) string {
	var sb strings.Builder
	sb.WriteString(filepath.Clean(wd))
	sb.WriteString("\n")
	sb.WriteString(tags)
	sb.WriteString("\n")
	for _, entry := range env {
		sb.WriteString(entry)
		sb.WriteString("\x00")
	}
	return sb.String()
}

// getIncrementalSession returns the shared session for the given inputs,
// creating it on first use. Safe for concurrent callers: LoadOrStore makes
// racing creators converge on a single instance.
func getIncrementalSession(wd string, env []string, tags string) *incrementalSession {
	key := sessionKey(wd, env, tags)
	if existing, ok := incrementalSessions.Load(key); ok {
		return existing.(*incrementalSession)
	}
	fresh := &incrementalSession{
		fset:       token.NewFileSet(),
		parsedDeps: make(map[string]cachedParsedFile),
	}
	stored, _ := incrementalSessions.LoadOrStore(key, fresh)
	return stored.(*incrementalSession)
}

// getParsedDep returns the cached AST for filename if it was parsed from
// byte-identical source; otherwise (nil, false). Nil receivers miss.
func (s *incrementalSession) getParsedDep(filename string, src []byte) (*ast.File, bool) {
	if s == nil {
		return nil, false
	}
	want := hashSource(src)
	s.mu.Lock()
	defer s.mu.Unlock()
	entry, ok := s.parsedDeps[filepath.Clean(filename)]
	if !ok || entry.hash != want {
		return nil, false
	}
	return entry.file, true
}

// storeParsedDep caches file as the parse result of src for filename.
// Nil receivers and nil files are ignored.
func (s *incrementalSession) storeParsedDep(filename string, src []byte, file *ast.File) {
	if s == nil || file == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.parsedDeps[filepath.Clean(filename)] = cachedParsedFile{
		hash: hashSource(src),
		file: file,
	}
}

// hashSource returns the hex-encoded SHA-256 of src.
func hashSource(src []byte) string {
	sum := sha256.Sum256(src)
	return hex.EncodeToString(sum[:])
}
-0,0 +1,647 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "bytes" + "crypto/sha256" + "encoding/binary" + "fmt" + "go/ast" + "go/types" + "path/filepath" + "sort" + + "golang.org/x/tools/go/packages" +) + +const incrementalSummaryVersion = "wire-incremental-summary-v1" + +type packageSummary struct { + Version string + WD string + Tags string + PkgPath string + ShapeHash string + LocalImports []string + ProviderSets []providerSetSummary + Injectors []injectorSummary +} + +type providerSetSummary struct { + VarName string + Providers []providerSummary + Imports []providerSetRefSummary + Bindings []ifaceBindingSummary + Values []string + Fields []fieldSummary + InputTypes []string +} + +type providerSummary struct { + PkgPath string + Name string + Args []providerInputSummary + Out []string + Varargs bool + IsStruct bool + HasCleanup bool + HasErr bool +} + +type providerInputSummary struct { + Type string + FieldName string +} + +type providerSetRefSummary struct { + PkgPath string + VarName string +} + +type ifaceBindingSummary struct { + Iface string + Provided string +} + +type fieldSummary struct { + PkgPath string + Parent string + Name string + Out []string +} + +type injectorSummary struct { + Name string + Inputs []string + Output string + Build providerSetSummary +} + +type packageSummarySnapshot struct { + Changed map[string]*packageSummary + Unchanged 
map[string]*packageSummary +} + +func incrementalSummaryKey(wd string, tags string, pkgPath string) string { + h := sha256.New() + h.Write([]byte(incrementalSummaryVersion)) + h.Write([]byte{0}) + h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte{0}) + h.Write([]byte(tags)) + h.Write([]byte{0}) + h.Write([]byte(pkgPath)) + return fmt.Sprintf("%x", h.Sum(nil)) +} + +func incrementalSummaryPath(key string) string { + return filepath.Join(cacheDir(), key+".isum") +} + +func readIncrementalPackageSummary(key string) (*packageSummary, bool) { + data, err := osReadFile(incrementalSummaryPath(key)) + if err != nil { + return nil, false + } + summary, err := decodeIncrementalSummary(data) + if err != nil { + return nil, false + } + return summary, true +} + +func writeIncrementalPackageSummary(key string, summary *packageSummary) { + data, err := encodeIncrementalSummary(summary) + if err != nil { + return + } + dir := cacheDir() + if err := osMkdirAll(dir, 0755); err != nil { + return + } + tmp, err := osCreateTemp(dir, key+".isum-") + if err != nil { + return + } + _, writeErr := tmp.Write(data) + closeErr := tmp.Close() + if writeErr != nil || closeErr != nil { + osRemove(tmp.Name()) + return + } + if err := osRename(tmp.Name(), incrementalSummaryPath(key)); err != nil { + osRemove(tmp.Name()) + } +} + +func writeIncrementalPackageSummaries(loader *lazyLoader, pkgs []*packages.Package) { + if loader == nil || len(pkgs) == 0 { + return + } + moduleRoot := findModuleRoot(loader.wd) + all := collectAllPackages(pkgs) + for path, pkg := range loader.loaded { + if pkg != nil { + all[path] = pkg + } + } + allPkgs := make([]*packages.Package, 0, len(all)) + for _, pkg := range all { + allPkgs = append(allPkgs, pkg) + } + oc := newObjectCache(allPkgs, loader) + for _, pkg := range all { + if classifyPackageLocation(moduleRoot, pkg) != "local" { + continue + } + if pkg == nil || pkg.TypesInfo == nil || len(pkg.Syntax) == 0 { + continue + } + summary, err := 
buildPackageSummary(loader, oc, pkg) + if err != nil { + continue + } + writeIncrementalPackageSummary(incrementalSummaryKey(loader.wd, loader.tags, pkg.PkgPath), summary) + } +} + +func collectIncrementalPackageSummaries(loader *lazyLoader, pkgs []*packages.Package) *packageSummarySnapshot { + if loader == nil || loader.fingerprints == nil { + return nil + } + snapshot := &packageSummarySnapshot{ + Changed: make(map[string]*packageSummary), + Unchanged: make(map[string]*packageSummary), + } + changed := make(map[string]struct{}, len(loader.fingerprints.changed)) + for _, path := range loader.fingerprints.changed { + changed[path] = struct{}{} + } + moduleRoot := findModuleRoot(loader.wd) + oc := newObjectCache(pkgs, loader) + for _, pkg := range collectAllPackages(pkgs) { + if pkg == nil { + continue + } + if classifyPackageLocation(moduleRoot, pkg) != "local" { + continue + } + if _, ok := changed[pkg.PkgPath]; ok { + if pkg.TypesInfo == nil || len(pkg.Syntax) == 0 { + loaded, errs := oc.ensurePackage(pkg.PkgPath) + if len(errs) > 0 { + continue + } + pkg = loaded + } + if pkg == nil || pkg.TypesInfo == nil || len(pkg.Syntax) == 0 { + continue + } + summary, err := buildPackageSummary(loader, oc, pkg) + if err != nil { + continue + } + snapshot.Changed[pkg.PkgPath] = summary + continue + } + if summary, ok := readIncrementalPackageSummary(incrementalSummaryKey(loader.wd, loader.tags, pkg.PkgPath)); ok { + snapshot.Unchanged[pkg.PkgPath] = summary + } + } + return snapshot +} + +func buildPackageSummary(loader *lazyLoader, oc *objectCache, pkg *packages.Package) (*packageSummary, error) { + if loader == nil || oc == nil || pkg == nil { + return nil, fmt.Errorf("missing loader, object cache, or package") + } + summary := &packageSummary{ + Version: incrementalSummaryVersion, + WD: filepath.Clean(loader.wd), + Tags: loader.tags, + PkgPath: pkg.PkgPath, + } + if snapshot := loader.fingerprints; snapshot != nil { + if fp := snapshot.fingerprints[pkg.PkgPath]; fp != 
nil { + summary.ShapeHash = fp.ShapeHash + summary.LocalImports = append(summary.LocalImports, fp.LocalImports...) + } + } + scope := pkg.Types.Scope() + for _, name := range scope.Names() { + obj := scope.Lookup(name) + if !isProviderSetType(obj.Type()) { + continue + } + item, errs := oc.get(obj) + if len(errs) > 0 { + continue + } + pset, ok := item.(*ProviderSet) + if !ok { + continue + } + summary.ProviderSets = append(summary.ProviderSets, summarizeProviderSet(pset)) + } + sort.Slice(summary.ProviderSets, func(i, j int) bool { + return summary.ProviderSets[i].VarName < summary.ProviderSets[j].VarName + }) + for _, file := range pkg.Syntax { + for _, decl := range file.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + buildCall, err := findInjectorBuild(pkg.TypesInfo, fn) + if err != nil || buildCall == nil { + continue + } + sig := pkg.TypesInfo.ObjectOf(fn.Name).Type().(*types.Signature) + ins, out, err := injectorFuncSignature(sig) + if err != nil { + continue + } + injectorArgs := &InjectorArgs{ + Name: fn.Name.Name, + Tuple: ins, + Pos: fn.Pos(), + } + set, errs := oc.processNewSet(pkg.TypesInfo, pkg.PkgPath, buildCall, injectorArgs, "") + if len(errs) > 0 { + continue + } + summary.Injectors = append(summary.Injectors, injectorSummary{ + Name: fn.Name.Name, + Inputs: summarizeTuple(ins), + Output: summaryTypeString(out.out), + Build: summarizeProviderSet(set), + }) + } + } + sort.Slice(summary.Injectors, func(i, j int) bool { + return summary.Injectors[i].Name < summary.Injectors[j].Name + }) + return summary, nil +} + +func summarizeProviderSet(pset *ProviderSet) providerSetSummary { + if pset == nil { + return providerSetSummary{} + } + summary := providerSetSummary{ + VarName: pset.VarName, + } + for _, provider := range pset.Providers { + summary.Providers = append(summary.Providers, summarizeProvider(provider)) + } + for _, imported := range pset.Imports { + summary.Imports = append(summary.Imports, providerSetRefSummary{ + 
PkgPath: imported.PkgPath, + VarName: imported.VarName, + }) + } + for _, binding := range pset.Bindings { + summary.Bindings = append(summary.Bindings, ifaceBindingSummary{ + Iface: summaryTypeString(binding.Iface), + Provided: summaryTypeString(binding.Provided), + }) + } + for _, value := range pset.Values { + summary.Values = append(summary.Values, summaryTypeString(value.Out)) + } + for _, field := range pset.Fields { + item := fieldSummary{ + Parent: summaryTypeString(field.Parent), + Name: field.Name, + Out: summarizeTypes(field.Out), + } + if field.Pkg != nil { + item.PkgPath = field.Pkg.Path() + } + summary.Fields = append(summary.Fields, item) + } + if pset.InjectorArgs != nil { + summary.InputTypes = summarizeTuple(pset.InjectorArgs.Tuple) + } + sort.Slice(summary.Providers, func(i, j int) bool { + return summary.Providers[i].PkgPath+"."+summary.Providers[i].Name < summary.Providers[j].PkgPath+"."+summary.Providers[j].Name + }) + sort.Slice(summary.Imports, func(i, j int) bool { + return summary.Imports[i].PkgPath+"."+summary.Imports[i].VarName < summary.Imports[j].PkgPath+"."+summary.Imports[j].VarName + }) + sort.Slice(summary.Bindings, func(i, j int) bool { + return summary.Bindings[i].Iface+":"+summary.Bindings[i].Provided < summary.Bindings[j].Iface+":"+summary.Bindings[j].Provided + }) + sort.Strings(summary.Values) + sort.Slice(summary.Fields, func(i, j int) bool { + return summary.Fields[i].Parent+"."+summary.Fields[i].Name < summary.Fields[j].Parent+"."+summary.Fields[j].Name + }) + sort.Strings(summary.InputTypes) + return summary +} + +func summarizeProvider(provider *Provider) providerSummary { + summary := providerSummary{ + Name: provider.Name, + Varargs: provider.Varargs, + IsStruct: provider.IsStruct, + HasCleanup: provider.HasCleanup, + HasErr: provider.HasErr, + Out: summarizeTypes(provider.Out), + } + if provider.Pkg != nil { + summary.PkgPath = provider.Pkg.Path() + } + for _, arg := range provider.Args { + summary.Args = 
append(summary.Args, providerInputSummary{ + Type: summaryTypeString(arg.Type), + FieldName: arg.FieldName, + }) + } + return summary +} + +func summarizeTuple(tuple *types.Tuple) []string { + if tuple == nil { + return nil + } + out := make([]string, 0, tuple.Len()) + for i := 0; i < tuple.Len(); i++ { + out = append(out, summaryTypeString(tuple.At(i).Type())) + } + return out +} + +func summarizeTypes(typesList []types.Type) []string { + out := make([]string, 0, len(typesList)) + for _, t := range typesList { + out = append(out, summaryTypeString(t)) + } + return out +} + +func summaryTypeString(t types.Type) string { + if t == nil { + return "" + } + return types.TypeString(t, func(pkg *types.Package) string { + if pkg == nil { + return "" + } + return pkg.Path() + }) +} + +func encodeIncrementalSummary(summary *packageSummary) ([]byte, error) { + if summary == nil { + return nil, fmt.Errorf("nil package summary") + } + var buf bytes.Buffer + enc := binarySummaryEncoder{buf: &buf} + enc.string(summary.Version) + enc.string(summary.WD) + enc.string(summary.Tags) + enc.string(summary.PkgPath) + enc.string(summary.ShapeHash) + enc.strings(summary.LocalImports) + enc.providerSets(summary.ProviderSets) + enc.u32(uint32(len(summary.Injectors))) + for _, injector := range summary.Injectors { + enc.string(injector.Name) + enc.strings(injector.Inputs) + enc.string(injector.Output) + enc.providerSet(injector.Build) + } + if enc.err != nil { + return nil, enc.err + } + return buf.Bytes(), nil +} + +func decodeIncrementalSummary(data []byte) (*packageSummary, error) { + dec := binarySummaryDecoder{r: bytes.NewReader(data)} + summary := &packageSummary{ + Version: dec.string(), + WD: dec.string(), + Tags: dec.string(), + PkgPath: dec.string(), + ShapeHash: dec.string(), + } + summary.LocalImports = dec.strings() + summary.ProviderSets = dec.providerSets() + for n := dec.u32(); n > 0; n-- { + summary.Injectors = append(summary.Injectors, injectorSummary{ + Name: dec.string(), 
+ Inputs: dec.strings(), + Output: dec.string(), + Build: dec.providerSet(), + }) + } + if dec.err != nil { + return nil, dec.err + } + return summary, nil +} + +type binarySummaryEncoder struct { + buf *bytes.Buffer + err error +} + +func (e *binarySummaryEncoder) u32(v uint32) { + if e.err != nil { + return + } + e.err = binary.Write(e.buf, binary.LittleEndian, v) +} + +func (e *binarySummaryEncoder) string(s string) { + e.u32(uint32(len(s))) + if e.err != nil { + return + } + _, e.err = e.buf.WriteString(s) +} + +func (e *binarySummaryEncoder) bool(v bool) { + if e.err != nil { + return + } + var b byte + if v { + b = 1 + } + e.err = e.buf.WriteByte(b) +} + +func (e *binarySummaryEncoder) strings(values []string) { + e.u32(uint32(len(values))) + for _, v := range values { + e.string(v) + } +} + +func (e *binarySummaryEncoder) providerSets(values []providerSetSummary) { + e.u32(uint32(len(values))) + for _, value := range values { + e.providerSet(value) + } +} + +func (e *binarySummaryEncoder) providerSet(value providerSetSummary) { + e.string(value.VarName) + e.u32(uint32(len(value.Providers))) + for _, provider := range value.Providers { + e.string(provider.PkgPath) + e.string(provider.Name) + e.u32(uint32(len(provider.Args))) + for _, arg := range provider.Args { + e.string(arg.Type) + e.string(arg.FieldName) + } + e.strings(provider.Out) + e.bool(provider.Varargs) + e.bool(provider.IsStruct) + e.bool(provider.HasCleanup) + e.bool(provider.HasErr) + } + e.u32(uint32(len(value.Imports))) + for _, imported := range value.Imports { + e.string(imported.PkgPath) + e.string(imported.VarName) + } + e.u32(uint32(len(value.Bindings))) + for _, binding := range value.Bindings { + e.string(binding.Iface) + e.string(binding.Provided) + } + e.strings(value.Values) + e.u32(uint32(len(value.Fields))) + for _, field := range value.Fields { + e.string(field.PkgPath) + e.string(field.Parent) + e.string(field.Name) + e.strings(field.Out) + } + e.strings(value.InputTypes) +} + 
+type binarySummaryDecoder struct { + r *bytes.Reader + err error +} + +func (d *binarySummaryDecoder) u32() uint32 { + if d.err != nil { + return 0 + } + var v uint32 + d.err = binary.Read(d.r, binary.LittleEndian, &v) + return v +} + +func (d *binarySummaryDecoder) string() string { + n := d.u32() + if d.err != nil { + return "" + } + buf := make([]byte, n) + _, d.err = d.r.Read(buf) + return string(buf) +} + +func (d *binarySummaryDecoder) bool() bool { + if d.err != nil { + return false + } + b, err := d.r.ReadByte() + if err != nil { + d.err = err + return false + } + return b != 0 +} + +func (d *binarySummaryDecoder) strings() []string { + n := d.u32() + if d.err != nil { + return nil + } + out := make([]string, 0, n) + for i := uint32(0); i < n; i++ { + out = append(out, d.string()) + } + return out +} + +func (d *binarySummaryDecoder) providerSets() []providerSetSummary { + n := d.u32() + if d.err != nil { + return nil + } + out := make([]providerSetSummary, 0, n) + for i := uint32(0); i < n; i++ { + out = append(out, d.providerSet()) + } + return out +} + +func (d *binarySummaryDecoder) providerSet() providerSetSummary { + value := providerSetSummary{ + VarName: d.string(), + } + for n := d.u32(); n > 0; n-- { + provider := providerSummary{ + PkgPath: d.string(), + Name: d.string(), + } + for m := d.u32(); m > 0; m-- { + provider.Args = append(provider.Args, providerInputSummary{ + Type: d.string(), + FieldName: d.string(), + }) + } + provider.Out = d.strings() + provider.Varargs = d.bool() + provider.IsStruct = d.bool() + provider.HasCleanup = d.bool() + provider.HasErr = d.bool() + value.Providers = append(value.Providers, provider) + } + for n := d.u32(); n > 0; n-- { + value.Imports = append(value.Imports, providerSetRefSummary{ + PkgPath: d.string(), + VarName: d.string(), + }) + } + for n := d.u32(); n > 0; n-- { + value.Bindings = append(value.Bindings, ifaceBindingSummary{ + Iface: d.string(), + Provided: d.string(), + }) + } + value.Values = 
d.strings() + for n := d.u32(); n > 0; n-- { + value.Fields = append(value.Fields, fieldSummary{ + PkgPath: d.string(), + Parent: d.string(), + Name: d.string(), + Out: d.strings(), + }) + } + value.InputTypes = d.strings() + return value +} diff --git a/internal/wire/incremental_summary_test.go b/internal/wire/incremental_summary_test.go new file mode 100644 index 0000000..efb4028 --- /dev/null +++ b/internal/wire/incremental_summary_test.go @@ -0,0 +1,287 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package wire + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestIncrementalSummaryEncodeDecodeRoundTrip(t *testing.T) { + summary := &packageSummary{ + Version: incrementalSummaryVersion, + WD: "/tmp/app", + Tags: "dev", + PkgPath: "example.com/app/dep", + ShapeHash: "abc123", + LocalImports: []string{"example.com/app/shared"}, + ProviderSets: []providerSetSummary{{ + VarName: "Set", + Providers: []providerSummary{{ + PkgPath: "example.com/app/dep", + Name: "NewThing", + Args: []providerInputSummary{{Type: "string"}}, + Out: []string{"*example.com/app/dep.Thing"}, + HasCleanup: true, + }}, + Imports: []providerSetRefSummary{{PkgPath: "example.com/app/shared", VarName: "SharedSet"}}, + Bindings: []ifaceBindingSummary{{Iface: "error", Provided: "*example.com/app/dep.Thing"}}, + Values: []string{"string"}, + Fields: []fieldSummary{{PkgPath: "example.com/app/dep", Parent: "example.com/app/dep.Config", Name: "Name", Out: []string{"string"}}}, + InputTypes: []string{"context.Context"}, + }}, + Injectors: []injectorSummary{{ + Name: "Init", + Inputs: []string{"context.Context"}, + Output: "*example.com/app/dep.Thing", + Build: providerSetSummary{ + Imports: []providerSetRefSummary{{PkgPath: "example.com/app/shared", VarName: "SharedSet"}}, + }, + }}, + } + data, err := encodeIncrementalSummary(summary) + if err != nil { + t.Fatalf("encodeIncrementalSummary: %v", err) + } + got, err := decodeIncrementalSummary(data) + if err != nil { + t.Fatalf("decodeIncrementalSummary: %v", err) + } + if got.Version != summary.Version || got.PkgPath != summary.PkgPath || got.ShapeHash != summary.ShapeHash { + t.Fatalf("decoded summary mismatch: %+v", got) + } + if len(got.ProviderSets) != 1 || got.ProviderSets[0].VarName != "Set" { + t.Fatalf("decoded provider sets mismatch: %+v", got.ProviderSets) + } + if len(got.Injectors) != 1 || got.Injectors[0].Name != "Init" { + t.Fatalf("decoded injectors mismatch: %+v", got.Injectors) + } +} + +func 
TestBuildPackageSummary(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.Set)", + "\treturn nil", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Foo struct{ Message string }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func New(msg string) *Foo { return &Foo{Message: msg} }", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + pkgs, loader, errs := load(ctx, root, env, "", []string{"./app"}) + if len(errs) > 0 { + t.Fatalf("load returned errors: %v", errs) + } + oc := newObjectCache(pkgs, loader) + loadedDep, errs := oc.ensurePackage("example.com/app/dep") + if len(errs) > 0 { + t.Fatalf("ensurePackage returned errors: %v", errs) + } + summary, err := buildPackageSummary(loader, oc, loadedDep) + if err != nil { + t.Fatalf("buildPackageSummary: %v", err) + } + if summary.PkgPath != "example.com/app/dep" { + t.Fatalf("summary pkg path = %q", summary.PkgPath) + } + if 
len(summary.ProviderSets) != 1 || summary.ProviderSets[0].VarName != "Set" { + t.Fatalf("unexpected provider sets: %+v", summary.ProviderSets) + } + if len(summary.ProviderSets[0].Providers) != 2 { + t.Fatalf("unexpected providers: %+v", summary.ProviderSets[0].Providers) + } + loadedApp, errs := oc.ensurePackage("example.com/app/app") + if len(errs) > 0 { + t.Fatalf("ensurePackage app returned errors: %v", errs) + } + appSummary, err := buildPackageSummary(loader, oc, loadedApp) + if err != nil { + t.Fatalf("buildPackageSummary app: %v", err) + } + if len(appSummary.Injectors) != 1 || appSummary.Injectors[0].Name != "Init" { + t.Fatalf("unexpected injectors: %+v", appSummary.Injectors) + } + if len(appSummary.Injectors[0].Build.Imports) != 1 || appSummary.Injectors[0].Build.Imports[0].PkgPath != "example.com/app/dep" { + t.Fatalf("unexpected injector imports: %+v", appSummary.Injectors[0].Build.Imports) + } +} + +func TestCollectIncrementalPackageSummariesUsesCacheForUnchanged(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.Set)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct{ Message string }", + "", + "func 
NewMessage() string { return \"ok\" }", + "", + "func New(msg string) *Foo { return &Foo{Message: msg} }", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + t.Fatalf("unexpected Generate result: %+v", gens) + } + + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct{ Message string; Count int }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func NewCount() int { return 7 }", + "", + "func New(msg string, count int) *Foo { return &Foo{Message: msg, Count: count} }", + "", + }, "\n")) + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + + pkgs, loader, errs := load(ctx, root, env, "", []string{"./app"}) + if len(errs) > 0 { + t.Fatalf("load returned errors: %v", errs) + } + snapshot := collectIncrementalPackageSummaries(loader, pkgs) + if snapshot == nil { + t.Fatal("collectIncrementalPackageSummaries returned nil") + } + if _, ok := snapshot.Changed["example.com/app/dep"]; !ok { + t.Fatalf("expected changed dep summary, got %+v", snapshot.Changed) + } + if _, ok := snapshot.Unchanged["example.com/app/app"]; !ok { + t.Fatalf("expected unchanged app summary from cache, got %+v", snapshot.Unchanged) + } + if len(snapshot.Unchanged["example.com/app/app"].Injectors) != 1 { + t.Fatalf("unexpected cached app summary: %+v", snapshot.Unchanged["example.com/app/app"]) + } + if 
len(snapshot.Changed["example.com/app/dep"].ProviderSets) != 1 { + t.Fatalf("unexpected changed dep summary: %+v", snapshot.Changed["example.com/app/dep"]) + } +} diff --git a/internal/wire/incremental_test.go b/internal/wire/incremental_test.go new file mode 100644 index 0000000..a531123 --- /dev/null +++ b/internal/wire/incremental_test.go @@ -0,0 +1,65 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package wire + +import ( + "context" + "testing" +) + +func TestIncrementalEnabledDefaultOff(t *testing.T) { + if IncrementalEnabled(context.Background(), nil) { + t.Fatal("IncrementalEnabled should default to false") + } +} + +func TestIncrementalEnabledFromEnv(t *testing.T) { + env := []string{ + "FOO=bar", + IncrementalEnvVar + "=true", + } + if !IncrementalEnabled(context.Background(), env) { + t.Fatal("IncrementalEnabled should read the environment variable") + } +} + +func TestIncrementalEnabledUsesLastEnvValue(t *testing.T) { + env := []string{ + IncrementalEnvVar + "=false", + IncrementalEnvVar + "=true", + } + if !IncrementalEnabled(context.Background(), env) { + t.Fatal("IncrementalEnabled should use the last matching env value") + } +} + +func TestIncrementalEnabledContextOverridesEnv(t *testing.T) { + env := []string{ + IncrementalEnvVar + "=false", + } + ctx := WithIncremental(context.Background(), true) + if !IncrementalEnabled(ctx, env) { + t.Fatal("context override should take precedence over env") + 
} +} + +func TestIncrementalEnabledInvalidEnvFallsBackFalse(t *testing.T) { + env := []string{ + IncrementalEnvVar + "=maybe", + } + if IncrementalEnabled(context.Background(), env) { + t.Fatal("invalid env value should not enable incremental mode") + } +} diff --git a/internal/wire/load_debug.go b/internal/wire/load_debug.go new file mode 100644 index 0000000..fd8c4d7 --- /dev/null +++ b/internal/wire/load_debug.go @@ -0,0 +1,304 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package wire + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "sync" + "time" + + "golang.org/x/tools/go/packages" +) + +type parseFileStats struct { + mu sync.Mutex + calls int + primaryCalls int + depCalls int + cacheHits int + cacheMisses int + errors int + total time.Duration +} + +func (ps *parseFileStats) record(primary bool, dur time.Duration, err error, cacheHit bool) { + ps.mu.Lock() + defer ps.mu.Unlock() + ps.calls++ + if primary { + ps.primaryCalls++ + } else { + ps.depCalls++ + } + if cacheHit { + ps.cacheHits++ + } else { + ps.cacheMisses++ + } + ps.total += dur + if err != nil { + ps.errors++ + } +} + +func (ps *parseFileStats) snapshot() parseFileStats { + ps.mu.Lock() + defer ps.mu.Unlock() + return parseFileStats{ + calls: ps.calls, + primaryCalls: ps.primaryCalls, + depCalls: ps.depCalls, + cacheHits: ps.cacheHits, + cacheMisses: ps.cacheMisses, + errors: ps.errors, + total: ps.total, + } +} + +type loadScopeStats struct { + roots int + totalPackages int + compiledFiles int + syntaxFiles int + packagesWithSyntax int + packagesWithTypes int + packagesWithTypesInfo int + localPackages int + localSyntaxPackages int + externalPackages int + externalSyntaxPkgs int + unknownPackages int + topCompiled []string + topSyntax []string +} + +type packageMetric struct { + path string + count int +} + +func logLoadDebug(ctx context.Context, scope string, mode packages.LoadMode, subject string, wd string, pkgs []*packages.Package, parseStats *parseFileStats) { + if timing(ctx) == nil { + return + } + stats := summarizeLoadScope(wd, pkgs) + debugf(ctx, "load.debug scope=%s subject=%s mode=%s roots=%d total_pkgs=%d compiled_files=%d syntax_files=%d syntax_pkgs=%d typed_pkgs=%d types_info_pkgs=%d local_pkgs=%d local_syntax_pkgs=%d external_pkgs=%d external_syntax_pkgs=%d unknown_pkgs=%d", + scope, + subject, + formatLoadMode(mode), + stats.roots, + stats.totalPackages, + stats.compiledFiles, + stats.syntaxFiles, + 
stats.packagesWithSyntax, + stats.packagesWithTypes, + stats.packagesWithTypesInfo, + stats.localPackages, + stats.localSyntaxPackages, + stats.externalPackages, + stats.externalSyntaxPkgs, + stats.unknownPackages, + ) + if len(stats.topCompiled) > 0 { + debugf(ctx, "load.debug scope=%s top_compiled_files=%s", scope, strings.Join(stats.topCompiled, ", ")) + } + if len(stats.topSyntax) > 0 { + debugf(ctx, "load.debug scope=%s top_syntax_files=%s", scope, strings.Join(stats.topSyntax, ", ")) + } + if parseStats != nil { + snap := parseStats.snapshot() + debugf(ctx, "load.debug scope=%s parse.calls=%d parse.primary=%d parse.deps=%d parse.cache_hits=%d parse.cache_misses=%d parse.errors=%d parse.total=%s", + scope, + snap.calls, + snap.primaryCalls, + snap.depCalls, + snap.cacheHits, + snap.cacheMisses, + snap.errors, + snap.total, + ) + } +} + +func summarizeLoadScope(wd string, pkgs []*packages.Package) loadScopeStats { + all := collectAllPackages(pkgs) + stats := loadScopeStats{ + roots: len(pkgs), + totalPackages: len(all), + } + moduleRoot := findModuleRoot(wd) + var compiled []packageMetric + var syntax []packageMetric + for _, pkg := range all { + if pkg == nil { + continue + } + compiledCount := len(pkg.CompiledGoFiles) + syntaxCount := len(pkg.Syntax) + stats.compiledFiles += compiledCount + stats.syntaxFiles += syntaxCount + if syntaxCount > 0 { + stats.packagesWithSyntax++ + } + if pkg.Types != nil { + stats.packagesWithTypes++ + } + if pkg.TypesInfo != nil { + stats.packagesWithTypesInfo++ + } + class := classifyPackageLocation(moduleRoot, pkg) + switch class { + case "local": + stats.localPackages++ + if syntaxCount > 0 { + stats.localSyntaxPackages++ + } + case "external": + stats.externalPackages++ + if syntaxCount > 0 { + stats.externalSyntaxPkgs++ + } + default: + stats.unknownPackages++ + } + if compiledCount > 0 { + compiled = append(compiled, packageMetric{path: pkg.PkgPath, count: compiledCount}) + } + if syntaxCount > 0 { + syntax = append(syntax, 
packageMetric{path: pkg.PkgPath, count: syntaxCount}) + } + } + stats.topCompiled = topPackageMetrics(compiled) + stats.topSyntax = topPackageMetrics(syntax) + return stats +} + +func classifyPackageLocation(moduleRoot string, pkg *packages.Package) string { + if moduleRoot == "" || pkg == nil { + return "unknown" + } + for _, name := range pkg.CompiledGoFiles { + if isWithinRoot(moduleRoot, name) { + return "local" + } + return "external" + } + for _, name := range pkg.GoFiles { + if isWithinRoot(moduleRoot, name) { + return "local" + } + return "external" + } + return "unknown" +} + +func isWithinRoot(root, name string) bool { + cleanRoot := filepath.Clean(root) + cleanName := filepath.Clean(name) + if cleanName == cleanRoot { + return true + } + rel, err := filepath.Rel(cleanRoot, cleanName) + if err != nil { + return false + } + return rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) +} + +func topPackageMetrics(metrics []packageMetric) []string { + sort.Slice(metrics, func(i, j int) bool { + if metrics[i].count == metrics[j].count { + return metrics[i].path < metrics[j].path + } + return metrics[i].count > metrics[j].count + }) + if len(metrics) > 5 { + metrics = metrics[:5] + } + out := make([]string, 0, len(metrics)) + for _, m := range metrics { + out = append(out, fmt.Sprintf("%s(%d)", m.path, m.count)) + } + return out +} + +func findModuleRoot(wd string) string { + dir := filepath.Clean(wd) + for { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + return "" + } + dir = parent + } +} + +func formatLoadMode(mode packages.LoadMode) string { + flags := []struct { + bit packages.LoadMode + name string + }{ + {packages.NeedName, "NeedName"}, + {packages.NeedFiles, "NeedFiles"}, + {packages.NeedCompiledGoFiles, "NeedCompiledGoFiles"}, + {packages.NeedImports, "NeedImports"}, + {packages.NeedDeps, "NeedDeps"}, + {packages.NeedExportsFile, 
"NeedExportsFile"}, + {packages.NeedTypes, "NeedTypes"}, + {packages.NeedSyntax, "NeedSyntax"}, + {packages.NeedTypesInfo, "NeedTypesInfo"}, + {packages.NeedTypesSizes, "NeedTypesSizes"}, + {packages.NeedModule, "NeedModule"}, + {packages.NeedEmbedFiles, "NeedEmbedFiles"}, + {packages.NeedEmbedPatterns, "NeedEmbedPatterns"}, + } + var parts []string + for _, flag := range flags { + if mode&flag.bit != 0 { + parts = append(parts, flag.name) + } + } + if len(parts) == 0 { + return "0" + } + return strings.Join(parts, "|") +} + +func primaryFileSet(files map[string]struct{}) map[string]struct{} { + if len(files) == 0 { + return nil + } + out := make(map[string]struct{}, len(files)) + for name := range files { + out[filepath.Clean(name)] = struct{}{} + } + return out +} + +func isPrimaryFile(primary map[string]struct{}, filename string) bool { + if len(primary) == 0 { + return false + } + _, ok := primary[filepath.Clean(filename)] + return ok +} diff --git a/internal/wire/loader_test.go b/internal/wire/loader_test.go index 1fbd96c..2f41c8d 100644 --- a/internal/wire/loader_test.go +++ b/internal/wire/loader_test.go @@ -20,6 +20,7 @@ import ( "path/filepath" "strings" "testing" + "time" ) func TestLoadAndGenerateModule(t *testing.T) { @@ -124,6 +125,925 @@ func TestLoadAndGenerateModule(t *testing.T) { } } +func TestLoadAndGenerateModuleIncrementalMatches(t *testing.T) { + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn 
nil", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + + info, errs := Load(context.Background(), root, env, "", []string{"./app"}) + if len(errs) > 0 { + t.Fatalf("Load returned errors: %v", errs) + } + if info == nil || len(info.Injectors) != 1 { + t.Fatalf("Load returned unexpected info: %+v errs=%v", info, errs) + } + + incrementalCtx := WithIncremental(context.Background(), true) + incrementalInfo, errs := Load(incrementalCtx, root, env, "", []string{"./app"}) + if len(errs) > 0 { + t.Fatalf("incremental Load returned errors: %v", errs) + } + if incrementalInfo == nil || len(incrementalInfo.Injectors) != 1 { + t.Fatalf("incremental Load returned unexpected info: %+v errs=%v", incrementalInfo, errs) + } + + normalGens, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("Generate returned errors: %v", errs) + } + incrementalGens, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("incremental Generate returned errors: %v", errs) + } + if len(normalGens) != 1 || len(incrementalGens) != 1 { + t.Fatalf("unexpected result counts: normal=%d incremental=%d", len(normalGens), len(incrementalGens)) + } + if len(normalGens[0].Errs) > 0 || len(incrementalGens[0].Errs) > 0 { + t.Fatalf("unexpected generate errors: normal=%v incremental=%v", normalGens[0].Errs, incrementalGens[0].Errs) + } + if normalGens[0].OutputPath != 
incrementalGens[0].OutputPath { + t.Fatalf("output paths differ: normal=%q incremental=%q", normalGens[0].OutputPath, incrementalGens[0].OutputPath) + } + if string(normalGens[0].Content) != string(incrementalGens[0].Content) { + t.Fatalf("generated content differs between normal and incremental modes") + } +} + +func TestGenerateIncrementalManifestSkipsLazyLoadOnBodyOnlyChange(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + writeFile(t, filepath.Join(root, "app", "wire_gen.go"), strings.Join([]string{ + "//go:build !wireinject", + "", + "package app", + "", + "func generated() {}", + "", + }, "\n")) + writeFile(t, filepath.Join(root, "app", "app_test.go"), strings.Join([]string{ + "package app", + "", + "func testOnly() {}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + 
"", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + var firstLabels []string + firstCtx := WithTiming(ctx, func(label string, _ time.Duration) { + firstLabels = append(firstLabels, label) + }) + first, errs := Generate(firstCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + if !containsLabel(firstLabels, "load.packages.lazy.load") { + t.Fatalf("expected first incremental generate to perform lazy load, labels=%v", firstLabels) + } + + if err := os.WriteFile(depFile, []byte(strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"b\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + + var secondLabels []string + secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { + secondLabels = append(secondLabels, label) + }) + second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("second Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected second Generate result: %+v", second) + } + if containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected second Generate to hit preload incremental manifest before package load, labels=%v", secondLabels) + } + if containsLabel(secondLabels, "load.packages.lazy.load") { + t.Fatalf("expected second Generate to skip lazy load, labels=%v", secondLabels) + } + if string(first[0].Content) != string(second[0].Content) { + t.Fatal("expected body-only change to reuse identical generated output") + } +} + 
+func TestGenerateIncrementalShapeChangeFallsBackToLazyLoad(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + wireFile := filepath.Join(root, "dep", "wire.go") + writeFile(t, wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + var firstLabels []string + firstCtx := WithTiming(ctx, func(label string, _ time.Duration) { + firstLabels = append(firstLabels, label) + }) + first, errs := Generate(firstCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + if !containsLabel(firstLabels, 
"load.packages.lazy.load") { + t.Fatalf("expected first incremental generate to perform lazy load, labels=%v", firstLabels) + } + + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string; Count int }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func NewCount() int { return 7 }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: msg, Count: count}", + "}", + "", + }, "\n")) + writeFile(t, wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + + var secondLabels []string + secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { + secondLabels = append(secondLabels, label) + }) + second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("second Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected second Generate result: %+v", second) + } + if containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected shape-changing incremental run to skip package load via local fast path, labels=%v", secondLabels) + } + if containsLabel(secondLabels, "load.packages.lazy.load") { + t.Fatalf("expected shape-changing incremental run to skip lazy load via local fast path, labels=%v", secondLabels) + } + if !containsLabel(secondLabels, "incremental.local_fastpath.load") { + t.Fatalf("expected shape-changing incremental run to use local fast path, labels=%v", secondLabels) + } + if string(first[0].Content) == string(second[0].Content) { + t.Fatal("expected shape-changing edit to regenerate different output") + } +} + +func TestGenerateIncrementalRepeatedShapeStateHitsPreloadManifest(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = 
func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string; Count int }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func NewCount() int { return 7 }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: msg, Count: count}", + "}", + "", + }, "\n")) + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + + var secondLabels []string + secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { + secondLabels = append(secondLabels, label) + }) + second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("second Generate returned errors: %v", errs) + } + if len(second) 
!= 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected second Generate result: %+v", second) + } + if containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected repeated shape state to hit preload manifest before package load, labels=%v", secondLabels) + } + if containsLabel(secondLabels, "load.packages.lazy.load") { + t.Fatalf("expected repeated shape state to skip lazy load, labels=%v", secondLabels) + } + if string(first[0].Content) != string(second[0].Content) { + t.Fatal("expected repeated shape state to reuse identical generated output") + } +} + +func TestGenerateIncrementalShapeChangeThenRepeatHitsPreloadManifest(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + writeFile(t, filepath.Join(root, "extra", "extra.go"), strings.Join([]string{ + "package extra", + "", + "type Marker struct{}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + wireFile := filepath.Join(root, "dep", "wire.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + 
writeFile(t, wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string; Count int }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func NewCount() int { return 7 }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: msg, Count: count}", + "}", + "", + }, "\n")) + writeFile(t, wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + + var secondLabels []string + secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { + secondLabels = append(secondLabels, label) + }) + second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("second Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected second Generate result: %+v", second) + } + if containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected shape-changing Generate to skip package load via local fast path, labels=%v", secondLabels) + } + if containsLabel(secondLabels, "load.packages.lazy.load") { + t.Fatalf("expected shape-changing Generate to skip lazy load via local fast path, labels=%v", secondLabels) + } + if !containsLabel(secondLabels, "incremental.local_fastpath.load") { + t.Fatalf("expected shape-changing Generate to use local fast 
path, labels=%v", secondLabels) + } + + var thirdLabels []string + thirdCtx := WithTiming(ctx, func(label string, _ time.Duration) { + thirdLabels = append(thirdLabels, label) + }) + third, errs := Generate(thirdCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("third Generate returned errors: %v", errs) + } + if len(third) != 1 || len(third[0].Errs) > 0 { + t.Fatalf("unexpected third Generate result: %+v", third) + } + if containsLabel(thirdLabels, "generate.load") { + t.Fatalf("expected repeated shape-changing state to hit preload manifest before package load, labels=%v", thirdLabels) + } + if containsLabel(thirdLabels, "load.packages.lazy.load") { + t.Fatalf("expected repeated shape-changing state to skip lazy load, labels=%v", thirdLabels) + } + if string(second[0].Content) != string(third[0].Content) { + t.Fatal("expected repeated shape-changing state to reuse identical generated output") + } +} + +func TestGenerateIncrementalShapeChangeMatchesNormalGenerate(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeIncrementalBenchmarkModule(t, repoRoot, root) + + env := append(os.Environ(), "GOWORK=off") + incrementalCtx := WithIncremental(context.Background(), true) + + if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + t.Fatalf("baseline incremental Generate returned errors: %v", errs) + } + + writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string; Count int }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func NewCount() int { return 7 }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: msg, Count: count}", + "}", + "", + }, "\n")) + writeFile(t, 
filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + + var incrementalLabels []string + incrementalTimedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { + incrementalLabels = append(incrementalLabels, label) + }) + incrementalGens, errs := Generate(incrementalTimedCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("incremental shape-change Generate returned errors: %v", errs) + } + if len(incrementalGens) != 1 || len(incrementalGens[0].Errs) > 0 { + t.Fatalf("unexpected incremental Generate results: %+v", incrementalGens) + } + if !containsLabel(incrementalLabels, "incremental.local_fastpath.load") { + t.Fatalf("expected incremental shape-change Generate to use local fast path, labels=%v", incrementalLabels) + } + + normalGens, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("normal Generate returned errors: %v", errs) + } + if len(normalGens) != 1 || len(normalGens[0].Errs) > 0 { + t.Fatalf("unexpected normal Generate results: %+v", normalGens) + } + if incrementalGens[0].OutputPath != normalGens[0].OutputPath { + t.Fatalf("output paths differ: incremental=%q normal=%q", incrementalGens[0].OutputPath, normalGens[0].OutputPath) + } + if string(incrementalGens[0].Content) != string(normalGens[0].Content) { + t.Fatal("shape-changing incremental output differs from normal Generate output") + } +} + +func TestGenerateIncrementalInvalidShapeChangeDoesNotReuseManifest(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module 
example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + wireFile := filepath.Join(root, "dep", "wire.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + writeFile(t, wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "import \"example.com/app/extra\"", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + var secondLabels []string + secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { + secondLabels = append(secondLabels, label) + }) + second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(second) != 0 { + 
t.Fatalf("expected invalid incremental generate to stop before generation, got %+v", second) + } + if len(errs) == 0 { + t.Fatal("expected invalid incremental generate to return errors") + } + if containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected invalid incremental generate to stop before slow-path load, labels=%v", secondLabels) + } + if got := errs[0].Error(); !strings.Contains(got, "type-check failed for example.com/app/app") { + t.Fatalf("expected fast-path type-check error, got %q", got) + } +} + +func TestGenerateIncrementalToggleBackToKnownShapeHitsArchivedPreloadManifest(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + wireFile := filepath.Join(root, "dep", "wire.go") + + oldDep := strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n") + newDep := strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string; Count int }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func NewCount() int { return 7 }", + 
"", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: msg, Count: count}", + "}", + "", + }, "\n") + oldWire := strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n") + newWire := strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n") + + writeFile(t, depFile, oldDep) + writeFile(t, wireFile, oldWire) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + + writeFile(t, depFile, newDep) + writeFile(t, wireFile, newWire) + second, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("second Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected second Generate result: %+v", second) + } + + writeFile(t, depFile, oldDep) + writeFile(t, wireFile, oldWire) + + var thirdLabels []string + thirdCtx := WithTiming(ctx, func(label string, _ time.Duration) { + thirdLabels = append(thirdLabels, label) + }) + third, errs := Generate(thirdCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("third Generate returned errors: %v", errs) + } + if len(third) != 1 || len(third[0].Errs) > 0 { + t.Fatalf("unexpected third Generate result: %+v", third) + } + if containsLabel(thirdLabels, "generate.load") { + t.Fatalf("expected toggled-back shape state to hit archived preload manifest before package load, labels=%v", thirdLabels) + } + if containsLabel(thirdLabels, "load.packages.lazy.load") { + 
// containsLabel reports whether want occurs anywhere in labels.
func containsLabel(labels []string, want string) bool {
	for i := range labels {
		if labels[i] == want {
			return true
		}
	}
	return false
}
+ +package wire + +import ( + "context" + "fmt" + "go/ast" + "go/format" + importerpkg "go/importer" + "go/parser" + "go/token" + "go/types" + "io" + "os" + "path/filepath" + "runtime" + "sort" + "strings" + "time" + + "golang.org/x/tools/go/packages" +) + +func tryIncrementalLocalFastPath(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions, state *incrementalPreloadState) ([]GenerateResult, bool, bool, []error) { + if state == nil || state.manifest == nil { + return nil, false, false, nil + } + if !strings.HasSuffix(state.reason, ".shape_mismatch") { + return nil, false, false, nil + } + roots := manifestOutputPkgPaths(state.manifest) + if len(roots) != 1 { + return nil, false, false, nil + } + changed := changedPackagePaths(state.manifest.LocalPackages, state.currentLocal) + if len(changed) != 1 { + return nil, false, false, nil + } + graph, ok := readIncrementalGraph(incrementalGraphKey(wd, opts.Tags, roots)) + if !ok { + return nil, false, false, nil + } + affected := affectedRoots(graph, changed) + if len(affected) != 1 || affected[0] != roots[0] { + return nil, false, false, nil + } + + fastPathStart := time.Now() + loaded, err := loadLocalPackagesForFastPath(ctx, wd, opts.Tags, roots[0], state.currentLocal, state.manifest.ExternalPkgs) + if err != nil { + debugf(ctx, "incremental.local_fastpath miss reason=%v", err) + if shouldBypassIncrementalManifestAfterFastPathError(err) { + return nil, true, true, []error{err} + } + return nil, false, false, nil + } + logTiming(ctx, "incremental.local_fastpath.load", fastPathStart) + + generated, errs := generateFromTypedPackages(ctx, loaded.root, loaded.allPackages, opts) + logTiming(ctx, "incremental.local_fastpath.generate", fastPathStart) + if len(errs) > 0 { + return nil, true, true, errs + } + + snapshot := &incrementalFingerprintSnapshot{ + fingerprints: loaded.fingerprints, + changed: append([]string(nil), changed...), + } + loader := &lazyLoader{ + ctx: ctx, + wd: wd, + env: 
// shouldBypassIncrementalManifestAfterFastPathError reports whether an error
// from the local fast path should force the incremental manifest to be
// bypassed. Type-check failures must not be served from a stale manifest.
func shouldBypassIncrementalManifestAfterFastPathError(err error) bool {
	if err == nil {
		return false
	}
	return strings.Contains(err.Error(), "type-check failed for ")
}

// normalizeErrorLines flattens a (possibly nested) type-check error message
// into display lines, stripping the module-root prefix from file paths.
// A blank message yields the single line "unknown error".
func normalizeErrorLines(msg string, root string) []string {
	trimmed := strings.TrimSpace(msg)
	if trimmed == "" {
		return []string{"unknown error"}
	}
	out := unfoldTypeCheckChain(trimmed)
	for i, line := range out {
		out[i] = relativizeErrorLine(line, root)
	}
	if len(out) == 0 {
		out = []string{"unknown error"}
	}
	return out
}

// relativizeErrorLine removes every occurrence of the cleaned root directory
// prefix (root plus path separator) from line. An empty root leaves the line
// unchanged.
func relativizeErrorLine(line string, root string) string {
	if root == "" {
		return line
	}
	return strings.ReplaceAll(line, filepath.Clean(root)+string(os.PathSeparator), "")
}

// unfoldTypeCheckChain splits a nested "outer (type-check failed for …)"
// message into its individual layers, outermost first. A message without a
// trailing nested type-check clause is split on newlines with blank lines
// dropped.
func unfoldTypeCheckChain(msg string) []string {
	msg = strings.TrimSpace(msg)
	if msg == "" {
		return nil
	}
	if inner, outer, ok := splitNestedTypeCheck(msg); ok {
		return append([]string{strings.TrimSpace(outer)}, unfoldTypeCheckChain(inner)...)
	}
	var out []string
	for _, piece := range strings.Split(msg, "\n") {
		if piece = strings.TrimSpace(piece); piece != "" {
			out = append(out, piece)
		}
	}
	return out
}

// splitNestedTypeCheck detects a trailing parenthesized "type-check failed
// for …" clause and splits the message into that inner clause and the outer
// remainder. ok is false when no such clause terminates the message.
func splitNestedTypeCheck(msg string) (inner string, outer string, ok bool) {
	msg = strings.TrimSpace(msg)
	if len(msg) < 2 || !strings.HasSuffix(msg, ")") {
		return "", "", false
	}
	// Walk backwards balancing parentheses to find the opener matching the
	// final ')'.
	depth := 0
	for i := len(msg) - 1; i >= 0; i-- {
		switch msg[i] {
		case ')':
			depth++
		case '(':
			depth--
			if depth != 0 {
				continue
			}
			candidate := strings.TrimSpace(msg[i+1 : len(msg)-1])
			if !strings.HasPrefix(candidate, "type-check failed for ") {
				return "", "", false
			}
			return candidate, strings.TrimSpace(msg[:i]), true
		}
	}
	return "", "", false
}
		fmt.Errorf("missing root package fingerprint")
	}
	// Index external (non-local) packages by import path; entries without a
	// path or an export-data file are unusable and skipped.
	externalMeta := make(map[string]externalPackageExport, len(external))
	for _, item := range external {
		if item.PkgPath == "" || item.ExportFile == "" {
			continue
		}
		externalMeta[item.PkgPath] = item
	}
	loader := &localFastPathLoader{
		ctx:          ctx,
		wd:           wd,
		tags:         tags,
		fset:         token.NewFileSet(),
		rootPkgPath:  rootPkgPath,
		meta:         meta,
		pkgs:         make(map[string]*packages.Package, len(meta)),
		externalMeta: externalMeta,
	}
	// External dependencies are resolved from on-disk export data via the
	// standard gc importer; openExternalExport supplies the export files.
	loader.externalImp = importerpkg.ForCompiler(loader.fset, "gc", loader.openExternalExport)
	root, err := loader.load(rootPkgPath)
	if err != nil {
		return nil, err
	}
	// Collect every package loaded (root plus transitively loaded locals) in
	// deterministic import-path order.
	all := make([]*packages.Package, 0, len(loader.pkgs))
	for _, pkg := range loader.pkgs {
		all = append(all, pkg)
	}
	sort.Slice(all, func(i, j int) bool { return all[i].PkgPath < all[j].PkgPath })
	return &localFastPathLoaded{
		fset:         loader.fset,
		root:         root,
		allPackages:  all,
		byPath:       loader.pkgs,
		fingerprints: loader.meta,
	}, nil
}

// load parses and type-checks the local package at pkgPath, recursively
// loading local imports through importPackage, and returns it as a
// *packages.Package. Results are memoized in l.pkgs. On success the package's
// fingerprint in l.meta is refreshed with its local imports, tags, and wd.
// Any type-check diagnostic aborts the fast path with a formatted error.
func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) {
	// Memoized: a package already loaded (or in progress) is returned as-is.
	if pkg := l.pkgs[pkgPath]; pkg != nil {
		return pkg, nil
	}
	fp := l.meta[pkgPath]
	if fp == nil {
		return nil, fmt.Errorf("package %s not tracked as local", pkgPath)
	}
	files := filesFromMeta(fp.Files)
	if len(files) == 0 {
		return nil, fmt.Errorf("package %s has no files", pkgPath)
	}
	// Parse each recorded source file from disk into the shared FileSet.
	mode := parser.ParseComments | parser.SkipObjectResolution
	syntax := make([]*ast.File, 0, len(files))
	for _, name := range files {
		file, err := parser.ParseFile(l.fset, name, nil, mode)
		if err != nil {
			return nil, err
		}
		syntax = append(syntax, file)
	}
	if len(syntax) == 0 {
		return nil, fmt.Errorf("package %s parsed no files", pkgPath)
	}

	pkgName := syntax[0].Name.Name
	info := &types.Info{
		Types:      make(map[ast.Expr]types.TypeAndValue),
		Defs:       make(map[*ast.Ident]types.Object),
		Uses:       make(map[*ast.Ident]types.Object),
		Implicits:  make(map[ast.Node]types.Object),
		Selections: make(map[*ast.SelectorExpr]*types.Selection),
		Scopes:     make(map[ast.Node]*types.Scope),
		Instances:  make(map[*ast.Ident]types.Instance),
	}
	pkg := &packages.Package{
		Fset:            l.fset,
		Name:            pkgName,
		PkgPath:         pkgPath,
		GoFiles:         append([]string(nil), files...),
		CompiledGoFiles: append([]string(nil), files...),
		Syntax:          syntax,
		TypesInfo:       info,
		Imports:         make(map[string]*packages.Package),
	}
	// Register before type-checking so recursive imports find the memoized
	// entry instead of re-entering load for the same path.
	l.pkgs[pkgPath] = pkg

	conf := &types.Config{
		Importer: importerFunc(func(path string) (*types.Package, error) {
			return l.importPackage(path)
		}),
		// NOTE(review): uses the host's runtime.GOARCH for type sizes;
		// presumably fine for same-host generation — confirm against the
		// GOARCH carried in env for cross-compilation setups.
		Sizes: types.SizesFor("gc", runtime.GOARCH),
		Error: func(err error) {
			pkg.Errors = append(pkg.Errors, packages.Error{Msg: err.Error()})
		},
	}
	checkedPkg, err := conf.Check(pkgPath, l.fset, syntax, info)
	if checkedPkg != nil {
		pkg.Types = checkedPkg
	}
	// Diagnostics are collected via conf.Error; fall back to the returned
	// error only when none were recorded.
	if err != nil && len(pkg.Errors) == 0 {
		return nil, err
	}
	if len(pkg.Errors) > 0 {
		return nil, formatLocalTypeCheckError(l.wd, pkgPath, pkg.Errors)
	}

	// Wire up local imports (already loaded through importPackage above) and
	// refresh this package's fingerprint metadata.
	imports := packageImportPaths(syntax)
	localImports := make([]string, 0, len(imports))
	for _, path := range imports {
		if dep := l.pkgs[path]; dep != nil {
			pkg.Imports[path] = dep
			localImports = append(localImports, path)
		}
	}
	sort.Strings(localImports)
	updated := *fp
	updated.LocalImports = localImports
	updated.Tags = l.tags
	updated.WD = filepath.Clean(l.wd)
	l.meta[pkgPath] = &updated
	return pkg, nil
}

// importPackage resolves an import path for the type-checker: paths tracked
// as local are loaded recursively via load; anything else is delegated to the
// external gc-export-data importer.
func (l *localFastPathLoader) importPackage(path string) (*types.Package, error) {
	if l.meta[path] != nil {
		pkg, err := l.load(path)
		if err != nil {
			return nil, err
		}
		return pkg.Types, nil
	}
	if l.externalImp == nil {
		return nil, fmt.Errorf("missing external importer")
	}
	return l.externalImp.Import(path)
}

// openExternalExport opens the recorded export-data file for an external
// package; it is the lookup function handed to importer.ForCompiler above.
func (l *localFastPathLoader) openExternalExport(path string) (io.ReadCloser, error) {
	meta, ok := l.externalMeta[path]
	if !ok || meta.ExportFile == "" {
		return nil,
fmt.Errorf("missing export data for %s", path) + } + return os.Open(meta.ExportFile) +} + +type importerFunc func(string) (*types.Package, error) + +func (fn importerFunc) Import(path string) (*types.Package, error) { + return fn(path) +} + +func packageImportPaths(files []*ast.File) []string { + seen := make(map[string]struct{}) + var out []string + for _, file := range files { + for _, spec := range file.Imports { + path := strings.Trim(spec.Path.Value, "\"") + if path == "" { + continue + } + if _, ok := seen[path]; ok { + continue + } + seen[path] = struct{}{} + out = append(out, path) + } + } + sort.Strings(out) + return out +} + +func generateFromTypedPackages(ctx context.Context, root *packages.Package, allPkgs []*packages.Package, opts *GenerateOptions) ([]GenerateResult, []error) { + if root == nil { + return nil, []error{fmt.Errorf("missing root package")} + } + if opts == nil { + opts = &GenerateOptions{} + } + pkgStart := time.Now() + res := GenerateResult{PkgPath: root.PkgPath} + outDir, err := detectOutputDir(root.GoFiles) + logTiming(ctx, "generate.package."+root.PkgPath+".output_dir", pkgStart) + if err != nil { + res.Errs = append(res.Errs, err) + return []GenerateResult{res}, nil + } + res.OutputPath = filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") + + oc := newObjectCache(allPkgs, nil) + g := newGen(root) + injectorStart := time.Now() + injectorFiles, errs := generateInjectors(oc, g, root) + logTiming(ctx, "generate.package."+root.PkgPath+".injectors", injectorStart) + if len(errs) > 0 { + res.Errs = errs + return []GenerateResult{res}, nil + } + copyStart := time.Now() + copyNonInjectorDecls(g, injectorFiles, root.TypesInfo) + logTiming(ctx, "generate.package."+root.PkgPath+".copy_non_injectors", copyStart) + frameStart := time.Now() + goSrc := g.frame(opts.Tags) + logTiming(ctx, "generate.package."+root.PkgPath+".frame", frameStart) + if len(opts.Header) > 0 { + goSrc = append(opts.Header, goSrc...) 
+ } + formatStart := time.Now() + fmtSrc, err := format.Source(goSrc) + logTiming(ctx, "generate.package."+root.PkgPath+".format", formatStart) + if err != nil { + res.Errs = append(res.Errs, err) + } else { + goSrc = fmtSrc + } + res.Content = goSrc + logTiming(ctx, "generate.package."+root.PkgPath+".total", pkgStart) + return []GenerateResult{res}, nil +} + +func writeIncrementalFingerprints(snapshot *incrementalFingerprintSnapshot, wd string, tags string) { + if snapshot == nil { + return + } + for _, fp := range snapshotPackageFingerprints(snapshot) { + fp := fp + writeIncrementalFingerprint(incrementalFingerprintKey(wd, tags, fp.PkgPath), &fp) + } +} + +func writeIncrementalManifestFromState(wd string, env []string, patterns []string, opts *GenerateOptions, state *incrementalPreloadState, snapshot *incrementalFingerprintSnapshot, generated []GenerateResult) { + if snapshot == nil || len(generated) == 0 || state == nil || state.manifest == nil { + return + } + manifest := &incrementalManifest{ + Version: incrementalManifestVersion, + WD: filepath.Clean(wd), + Tags: opts.Tags, + Prefix: opts.PrefixOutputFile, + HeaderHash: headerHash(opts.Header), + EnvHash: envHash(env), + Patterns: sortedStrings(patterns), + LocalPackages: snapshotPackageFingerprints(snapshot), + ExternalPkgs: append([]externalPackageExport(nil), state.manifest.ExternalPkgs...), + ExternalFiles: append([]cacheFile(nil), state.manifest.ExternalFiles...), + ExtraFiles: extraCacheFiles(wd), + } + for _, out := range generated { + if len(out.Content) == 0 || out.OutputPath == "" { + continue + } + contentKey := incrementalContentKey(out.Content) + writeCache(contentKey, out.Content) + manifest.Outputs = append(manifest.Outputs, incrementalOutput{ + PkgPath: out.PkgPath, + OutputPath: out.OutputPath, + ContentKey: contentKey, + }) + } + if len(manifest.Outputs) == 0 { + return + } + selectorKey := incrementalManifestSelectorKey(wd, env, patterns, opts) + writeIncrementalManifestFile(selectorKey, 
manifest) + writeIncrementalManifestFile(incrementalManifestStateKey(selectorKey, manifest.LocalPackages), manifest) +} + +func writeIncrementalGraphFromSnapshot(wd string, tags string, roots []string, fps map[string]*packageFingerprint) { + if len(roots) == 0 || len(fps) == 0 { + return + } + graph := &incrementalGraph{ + Version: incrementalGraphVersion, + WD: filepath.Clean(wd), + Tags: tags, + Roots: append([]string(nil), roots...), + LocalReverse: make(map[string][]string), + } + sort.Strings(graph.Roots) + for _, fp := range fps { + if fp == nil { + continue + } + for _, imp := range fp.LocalImports { + graph.LocalReverse[imp] = append(graph.LocalReverse[imp], fp.PkgPath) + } + } + for path := range graph.LocalReverse { + sort.Strings(graph.LocalReverse[path]) + } + writeIncrementalGraph(incrementalGraphKey(wd, tags, graph.Roots), graph) +} + +func manifestOutputPkgPaths(manifest *incrementalManifest) []string { + if manifest == nil || len(manifest.Outputs) == 0 { + return nil + } + seen := make(map[string]struct{}, len(manifest.Outputs)) + paths := make([]string, 0, len(manifest.Outputs)) + for _, out := range manifest.Outputs { + if out.PkgPath == "" { + continue + } + if _, ok := seen[out.PkgPath]; ok { + continue + } + seen[out.PkgPath] = struct{}{} + paths = append(paths, out.PkgPath) + } + sort.Strings(paths) + return paths +} + +func changedPackagePaths(previous []packageFingerprint, current []packageFingerprint) []string { + if len(current) == 0 { + return nil + } + prevByPath := make(map[string]packageFingerprint, len(previous)) + for _, fp := range previous { + prevByPath[fp.PkgPath] = fp + } + changed := make([]string, 0, len(current)) + for _, fp := range current { + prev, ok := prevByPath[fp.PkgPath] + if !ok || !incrementalFingerprintEquivalent(&prev, &fp) { + changed = append(changed, fp.PkgPath) + } + } + sort.Strings(changed) + return changed +} diff --git a/internal/wire/parse.go b/internal/wire/parse.go index fc1b353..2f038a9 100644 --- 
a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -250,6 +250,9 @@ type Field struct { // In case of duplicate environment variables, the last one in the list // takes precedence. func Load(ctx context.Context, wd string, env []string, tags string, patterns []string) (*Info, []error) { + if IncrementalEnabled(ctx, env) { + debugf(ctx, "incremental=enabled") + } loadStart := time.Now() pkgs, loader, errs := load(ctx, wd, env, tags, patterns) logTiming(ctx, "load.packages", loadStart) @@ -365,7 +368,13 @@ func Load(ctx context.Context, wd string, env []string, tags string, patterns [] // In case of duplicate environment variables, the last one in the list // takes precedence. func load(ctx context.Context, wd string, env []string, tags string, patterns []string) ([]*packages.Package, *lazyLoader, []error) { + var session *incrementalSession fset := token.NewFileSet() + if IncrementalEnabled(ctx, env) { + session = getIncrementalSession(wd, env, tags) + fset = session.fset + debugf(ctx, "incremental session=enabled") + } baseCfg := &packages.Config{ Context: ctx, Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps, @@ -384,6 +393,7 @@ func load(ctx context.Context, wd string, env []string, tags string, patterns [] baseLoadStart := time.Now() pkgs, err := packages.Load(baseCfg, escaped...) 
logTiming(ctx, "load.packages.base.load", baseLoadStart) + logLoadDebug(ctx, "base", baseCfg.Mode, strings.Join(patterns, ","), wd, pkgs, nil) if err != nil { return nil, nil, []error{err} } @@ -393,15 +403,19 @@ func load(ctx context.Context, wd string, env []string, tags string, patterns [] if len(errs) > 0 { return nil, nil, errs } + fingerprints := analyzeIncrementalFingerprints(ctx, wd, env, tags, pkgs) + analyzeIncrementalGraph(ctx, wd, env, tags, pkgs, fingerprints) baseFiles := collectPackageFiles(pkgs) loader := &lazyLoader{ - ctx: ctx, - wd: wd, - env: env, - tags: tags, - fset: fset, - baseFiles: baseFiles, + ctx: ctx, + wd: wd, + env: env, + tags: tags, + fset: fset, + baseFiles: baseFiles, + session: session, + fingerprints: fingerprints, } return pkgs, loader, nil } diff --git a/internal/wire/parser_lazy_loader.go b/internal/wire/parser_lazy_loader.go index b3d7011..223c9ad 100644 --- a/internal/wire/parser_lazy_loader.go +++ b/internal/wire/parser_lazy_loader.go @@ -26,12 +26,15 @@ import ( ) type lazyLoader struct { - ctx context.Context - wd string - env []string - tags string - fset *token.FileSet - baseFiles map[string]map[string]struct{} + ctx context.Context + wd string + env []string + tags string + fset *token.FileSet + baseFiles map[string]map[string]struct{} + session *incrementalSession + fingerprints *incrementalFingerprintSnapshot + loaded map[string]*packages.Package } func collectPackageFiles(pkgs []*packages.Package) map[string]map[string]struct{} { @@ -74,10 +77,11 @@ func (ll *lazyLoader) load(pkgPath string) ([]*packages.Package, []error) { } func (ll *lazyLoader) fullMode() packages.LoadMode { - return packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax + return packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | 
packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile } func (ll *lazyLoader) loadWithMode(pkgPath string, mode packages.LoadMode, timingLabel string) ([]*packages.Package, []error) { + parseStats := &parseFileStats{} cfg := &packages.Config{ Context: ll.ctx, Mode: mode, @@ -85,7 +89,7 @@ func (ll *lazyLoader) loadWithMode(pkgPath string, mode packages.LoadMode, timin Env: ll.env, BuildFlags: []string{"-tags=wireinject"}, Fset: ll.fset, - ParseFile: ll.parseFileFor(pkgPath), + ParseFile: ll.parseFileFor(pkgPath, parseStats), } if len(ll.tags) > 0 { cfg.BuildFlags[0] += " " + ll.tags @@ -93,6 +97,7 @@ func (ll *lazyLoader) loadWithMode(pkgPath string, mode packages.LoadMode, timin loadStart := time.Now() pkgs, err := packages.Load(cfg, "pattern="+pkgPath) logTiming(ll.ctx, timingLabel, loadStart) + logLoadDebug(ll.ctx, "lazy", mode, pkgPath, ll.wd, pkgs, parseStats) if err != nil { return nil, []error{err} } @@ -100,26 +105,52 @@ func (ll *lazyLoader) loadWithMode(pkgPath string, mode packages.LoadMode, timin if len(errs) > 0 { return nil, errs } + ll.rememberPackages(pkgs) return pkgs, nil } -func (ll *lazyLoader) parseFileFor(pkgPath string) func(*token.FileSet, string, []byte) (*ast.File, error) { - primary := ll.baseFiles[pkgPath] +func (ll *lazyLoader) rememberPackages(pkgs []*packages.Package) { + if ll == nil || len(pkgs) == 0 { + return + } + if ll.loaded == nil { + ll.loaded = make(map[string]*packages.Package) + } + for path, pkg := range collectAllPackages(pkgs) { + if pkg != nil { + ll.loaded[path] = pkg + } + } +} + +func (ll *lazyLoader) parseFileFor(pkgPath string, stats *parseFileStats) func(*token.FileSet, string, []byte) (*ast.File, error) { + primary := primaryFileSet(ll.baseFiles[pkgPath]) return func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - mode := parser.SkipObjectResolution - if primary != nil { - if _, ok := primary[filepath.Clean(filename)]; ok { - mode = parser.ParseComments | 
parser.SkipObjectResolution + start := time.Now() + isPrimary := isPrimaryFile(primary, filename) + if !isPrimary && ll.session != nil { + if file, ok := ll.session.getParsedDep(filename, src); ok { + if stats != nil { + stats.record(false, time.Since(start), nil, true) + } + return file, nil } } + mode := parser.SkipObjectResolution + if isPrimary { + mode = parser.ParseComments | parser.SkipObjectResolution + } file, err := parser.ParseFile(fset, filename, src, mode) + if stats != nil { + stats.record(isPrimary, time.Since(start), err, false) + } if err != nil { return nil, err } if primary == nil { return file, nil } - if _, ok := primary[filepath.Clean(filename)]; ok { + if isPrimary { return file, nil } for _, decl := range file.Decls { @@ -128,6 +159,9 @@ func (ll *lazyLoader) parseFileFor(pkgPath string) func(*token.FileSet, string, fn.Doc = nil } } + if ll.session != nil { + ll.session.storeParsedDep(filename, src, file) + } return file, nil } } diff --git a/internal/wire/parser_lazy_loader_test.go b/internal/wire/parser_lazy_loader_test.go index 31838ea..86b49da 100644 --- a/internal/wire/parser_lazy_loader_test.go +++ b/internal/wire/parser_lazy_loader_test.go @@ -47,7 +47,7 @@ func TestLazyLoaderParseFileFor(t *testing.T) { "", }, "\n") - parse := ll.parseFileFor(pkgPath) + parse := ll.parseFileFor(pkgPath, &parseFileStats{}) file, err := parse(fset, primary, []byte(src)) if err != nil { t.Fatalf("parse primary: %v", err) @@ -73,6 +73,59 @@ func TestLazyLoaderParseFileFor(t *testing.T) { } } +func TestLazyLoaderParseFileForCachesDependencyFiles(t *testing.T) { + t.Helper() + fset := token.NewFileSet() + pkgPath := "example.com/pkg" + root := t.TempDir() + primary := filepath.Join(root, "primary.go") + secondary := filepath.Join(root, "secondary.go") + session := &incrementalSession{ + fset: fset, + parsedDeps: make(map[string]cachedParsedFile), + } + ll := &lazyLoader{ + fset: fset, + baseFiles: map[string]map[string]struct{}{ + pkgPath: 
// NOTE(review): this block opens mid-way through the new test
// TestLazyLoaderParseFileForCachesDependencyFiles (its head, which builds the
// lazyLoader literal below, precedes this chunk of the patch). Reproduced
// verbatim as a continuation.
			{filepath.Clean(primary): {}},
		},
		session: session,
	}
	src := []byte(strings.Join([]string{
		"package pkg",
		"",
		"func Foo() {",
		"\tprintln(\"hi\")",
		"}",
		"",
	}, "\n"))

	// First parse of the non-primary file: must miss the session cache and
	// record exactly one miss.
	stats1 := &parseFileStats{}
	parse1 := ll.parseFileFor(pkgPath, stats1)
	file1, err := parse1(fset, secondary, src)
	if err != nil {
		t.Fatalf("first parse: %v", err)
	}
	snap1 := stats1.snapshot()
	if snap1.cacheHits != 0 || snap1.cacheMisses != 1 {
		t.Fatalf("first parse stats = %+v, want 0 hits and 1 miss", snap1)
	}

	// Second parse of the same file/content: must hit the session cache and
	// return the identical *ast.File.
	stats2 := &parseFileStats{}
	parse2 := ll.parseFileFor(pkgPath, stats2)
	file2, err := parse2(fset, secondary, src)
	if err != nil {
		t.Fatalf("second parse: %v", err)
	}
	if file1 != file2 {
		t.Fatal("expected cached dependency parse to reuse AST")
	}
	snap2 := stats2.snapshot()
	if snap2.cacheHits != 1 || snap2.cacheMisses != 0 {
		t.Fatalf("second parse stats = %+v, want 1 hit and 0 misses", snap2)
	}
}

// NOTE(review): only the first lines of this test are visible here; its body
// continues past the diff header below and is not reproduced.
func TestLoadModuleUsesWireinjectTagsForDeps(t *testing.T) {
	repoRoot := mustRepoRoot(t)
	root := t.TempDir()

// --- patch metadata: new file internal/wire/time_compat.go (mode 0644) ---

// Copyright 2026 The Wire Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package wire

import "time"

// Indirections over the clock so tests can substitute a fake time source.
var (
	timeNow   = time.Now
	timeSince = time.Since
)

// --- patch metadata: hunks modifying internal/wire/timing.go follow; the
// --- import block gains "log", and logTiming (shown only as trailing context
// --- "t(label, time.Since(start))") is unchanged. New addition: ---

// debugf logs a formatted debug message, but only when a timing callback is
// attached to ctx — the same gate logTiming uses, so debug output appears only
// under profiling.
func debugf(ctx context.Context, format string, args ...interface{}) {
	if timing(ctx) == nil {
		return
	}
	log.Printf("timing: "+format, args...)
}

// --- patch metadata: hunks modifying internal/wire/wire.go follow; the
// --- import block gains "os", and Commit is rewritten as: ---

// Commit writes the generated file to disk.
func (gen GenerateResult) Commit() error {
	_, err := gen.CommitWithStatus()
	return err
}

// CommitWithStatus writes the generated file to disk when the content changed.
// It returns whether the file was written.
func (gen GenerateResult) CommitWithStatus() (bool, error) {
	if len(gen.Content) == 0 {
		return false, nil
	}
	// Skip the write (and mtime bump) when the on-disk content already matches.
	current, err := os.ReadFile(gen.OutputPath)
	if err == nil && bytes.Equal(current, gen.Content) {
		return false, nil
	}
	// A missing file is the normal first-write case; any other read error is fatal.
	if err != nil && !os.IsNotExist(err) {
		return false, err
	}
	if err := ioutil.WriteFile(gen.OutputPath, gen.Content, 0666); err != nil {
		return false, err
	}
	return true, nil
}

// GenerateOptions holds options for Generate.
@@ -83,6 +101,20 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o if opts == nil { opts = &GenerateOptions{} } + var preloadState *incrementalPreloadState + bypassIncrementalManifest := false + if IncrementalEnabled(ctx, env) { + debugf(ctx, "incremental=enabled") + preloadState, _ = prepareIncrementalPreloadState(ctx, wd, env, patterns, opts) + if cached, ok := readPreloadIncrementalManifestResultsFromState(ctx, wd, env, patterns, opts, preloadState, preloadState != nil); ok { + return cached, nil + } + if generated, ok, bypass, errs := tryIncrementalLocalFastPath(ctx, wd, env, patterns, opts, preloadState); ok || len(errs) > 0 { + return generated, errs + } else if bypass { + bypassIncrementalManifest = true + } + } if cached, ok := readManifestResults(wd, env, patterns, opts); ok { return cached, nil } @@ -92,16 +124,69 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o if len(errs) > 0 { return nil, errs } + if !bypassIncrementalManifest { + if cached, ok := readIncrementalManifestResults(ctx, wd, env, patterns, opts, pkgs, loader.fingerprints); ok { + warmPackageOutputCache(pkgs, opts, cached) + return cached, nil + } + } else { + debugf(ctx, "incremental.manifest bypass reason=fastpath_error") + ctx = withBypassPackageCache(ctx) + } generated := make([]GenerateResult, len(pkgs)) for i, pkg := range pkgs { generated[i] = generateForPackage(ctx, pkg, loader, opts) } if allGeneratedOK(generated) { + if IncrementalEnabled(ctx, env) { + writeIncrementalPackageSummaries(loader, pkgs) + } writeManifest(wd, env, patterns, opts, pkgs) + writeIncrementalManifest(wd, env, patterns, opts, incrementalManifestPackages(pkgs, loader), loader.fingerprints, generated) } return generated, nil } +func warmPackageOutputCache(pkgs []*packages.Package, opts *GenerateOptions, generated []GenerateResult) { + if len(pkgs) == 0 || len(generated) == 0 { + return + } + byPkg := make(map[string][]byte, len(generated)) + 
for _, gen := range generated { + if len(gen.Content) == 0 { + continue + } + byPkg[gen.PkgPath] = gen.Content + } + for _, pkg := range pkgs { + content := byPkg[pkg.PkgPath] + if len(content) == 0 { + continue + } + key, err := cacheKeyForPackage(pkg, opts) + if err != nil || key == "" { + continue + } + writeCache(key, content) + } +} + +func incrementalManifestPackages(pkgs []*packages.Package, loader *lazyLoader) []*packages.Package { + if loader == nil || len(loader.loaded) == 0 { + return pkgs + } + out := make([]*packages.Package, 0, len(loader.loaded)) + for _, pkg := range loader.loaded { + if pkg != nil { + out = append(out, pkg) + } + } + if len(out) == 0 { + return pkgs + } + return out +} + // generateInjectors generates the injectors for a given package. func generateInjectors(oc *objectCache, g *gen, pkg *packages.Package) (injectorFiles []*ast.File, _ []error) { injectorFiles = make([]*ast.File, 0, len(pkg.Syntax)) diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go index 14080df..cb167aa 100644 --- a/internal/wire/wire_test.go +++ b/internal/wire/wire_test.go @@ -111,6 +111,7 @@ func TestWire(t *testing.T) { t.Log(e.Error()) gotErrStrings[i] = scrubError(gopath, e.Error()) } + gotErrStrings = filterLegacyCompilerErrors(gotErrStrings) if !test.wantWireError { t.Fatal("Did not expect errors. 
To -record an error, create want/wire_errs.txt.") } @@ -191,6 +192,33 @@ func TestGenerateResultCommit(t *testing.T) { } } +func TestGenerateResultCommitWithStatus(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "wire_gen.go") + gen := GenerateResult{ + OutputPath: path, + Content: []byte("package p\n"), + } + + wrote, err := gen.CommitWithStatus() + if err != nil { + t.Fatalf("first CommitWithStatus failed: %v", err) + } + if !wrote { + t.Fatal("expected first CommitWithStatus call to write") + } + + wrote, err = gen.CommitWithStatus() + if err != nil { + t.Fatalf("second CommitWithStatus failed: %v", err) + } + if wrote { + t.Fatal("expected second CommitWithStatus call to report unchanged") + } +} + func TestZeroValue(t *testing.T) { t.Parallel() @@ -521,6 +549,28 @@ func scrubLineColumn(s string) (replacement string, n int) { return ":x:y", n } +func filterLegacyCompilerErrors(errs []string) []string { + hasCanonicalPath := false + for _, err := range errs { + if strings.HasPrefix(err, "example.com/") { + hasCanonicalPath = true + break + } + } + if !hasCanonicalPath { + return errs + } + + filtered := errs[:0] + for _, err := range errs { + if strings.HasPrefix(err, "-: # ") { + continue + } + filtered = append(filtered, err) + } + return filtered +} + type testCase struct { name string pkg string From e3f07cb43397fbd5e25cb45c334300fc764126ef Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Fri, 13 Mar 2026 23:35:20 -0500 Subject: [PATCH 02/82] feat(incremental): reuse unchanged local packages in fast path and harden fallback behavior --- internal/wire/incremental_bench_test.go | 117 +++++-- internal/wire/incremental_manifest.go | 4 - internal/wire/incremental_summary.go | 11 +- internal/wire/loader_test.go | 142 ++++++++ internal/wire/local_fastpath.go | 371 +++++++++++++++++++-- internal/wire/parse.go | 20 +- internal/wire/summary_provider_resolver.go | 223 +++++++++++++ 7 files changed, 826 insertions(+), 62 deletions(-) 
create mode 100644 internal/wire/summary_provider_resolver.go diff --git a/internal/wire/incremental_bench_test.go b/internal/wire/incremental_bench_test.go index b981d23..911a8c7 100644 --- a/internal/wire/incremental_bench_test.go +++ b/internal/wire/incremental_bench_test.go @@ -5,15 +5,15 @@ import ( "fmt" "os" "path/filepath" + "strconv" "strings" - "text/tabwriter" "testing" "time" ) const ( largeBenchmarkTestPackageCount = 24 - largeBenchmarkHelperCount = 12 + largeBenchmarkHelperCount = 12 ) var largeBenchmarkSizes = []int{10, 100, 1000} @@ -106,36 +106,42 @@ func TestPrintLargeRepoBenchmarkComparisonTable(t *testing.T) { incremental := measureLargeRepoShapeChangeOnce(t, repoRoot, packageCount, true) knownToggle := measureLargeRepoKnownToggleOnce(t, repoRoot, packageCount) rows = append(rows, largeRepoBenchmarkRow{ - packageCount: packageCount, - coldNormal: coldNormal, - coldIncremental: coldIncremental, - normal: normal, - incremental: incremental, - knownToggle: knownToggle, + packageCount: packageCount, + coldNormal: coldNormal, + coldIncremental: coldIncremental, + normal: normal, + incremental: incremental, + knownToggle: knownToggle, }) } - var out strings.Builder - tw := tabwriter.NewWriter(&out, 0, 0, 2, ' ', 0) - fmt.Fprintln(tw, "size\tcold normal\tcold incr\tcold delta\tcold x\tshape normal\tshape incr\tshape delta\tshape x\tknown toggle") + table := [][]string{{ + "repo size", + "cold old", + "cold new", + "cold delta", + "shape old", + "shape new", + "shape delta", + "known toggle", + "cold speedup", + "shape speedup", + }} for _, row := range rows { - fmt.Fprintf(tw, "%d\t%s\t%s\t%s\t%.2fx\t%s\t%s\t%s\t%.2fx\t%s\n", - row.packageCount, + table = append(table, []string{ + strconv.Itoa(row.packageCount), formatBenchmarkDuration(row.coldNormal), formatBenchmarkDuration(row.coldIncremental), formatPercentImprovement(row.coldNormal, row.coldIncremental), - speedupRatio(row.coldNormal, row.coldIncremental), formatBenchmarkDuration(row.normal), 
formatBenchmarkDuration(row.incremental), formatPercentImprovement(row.normal, row.incremental), - speedupRatio(row.normal, row.incremental), formatBenchmarkDuration(row.knownToggle), - ) - } - if err := tw.Flush(); err != nil { - t.Fatalf("flush benchmark table: %v", err) + fmt.Sprintf("%.2fx", speedupRatio(row.coldNormal, row.coldIncremental)), + fmt.Sprintf("%.2fx", speedupRatio(row.normal, row.incremental)), + }) } - fmt.Print(out.String()) + fmt.Print(renderASCIITable(table)) } func TestPrintLargeRepoShapeChangeBreakdownTable(t *testing.T) { @@ -151,27 +157,35 @@ func TestPrintLargeRepoShapeChangeBreakdownTable(t *testing.T) { }) repoRoot := benchmarkRepoRoot(t) - var out strings.Builder - tw := tabwriter.NewWriter(&out, 0, 0, 2, ' ', 0) - fmt.Fprintln(tw, "size\tnormal total\tbase load\tlazy load\tincr total\tfast load\tfast generate\tspeedup") + rows := [][]string{{ + "repo size", + "old total", + "old base load", + "old typed load", + "new total", + "new local load", + "new cached sets", + "new injector solve", + "new generate", + "speedup", + }} for _, packageCount := range largeBenchmarkSizes { normal := measureLargeRepoShapeChangeTraceOnce(t, repoRoot, packageCount, false) incremental := measureLargeRepoShapeChangeTraceOnce(t, repoRoot, packageCount, true) - fmt.Fprintf(tw, "%d\t%s\t%s\t%s\t%s\t%s\t%s\t%.2fx\n", - packageCount, + rows = append(rows, []string{ + strconv.Itoa(packageCount), formatBenchmarkDuration(normal.total), formatBenchmarkDuration(normal.label("load.packages.base.load")), formatBenchmarkDuration(normal.label("load.packages.lazy.load")), formatBenchmarkDuration(incremental.total), formatBenchmarkDuration(incremental.label("incremental.local_fastpath.load")), + formatBenchmarkDuration(incremental.label("incremental.local_fastpath.summary_resolve")), + formatBenchmarkDuration(incremental.label("generate.package.example.com/app/app.injectors")), formatBenchmarkDuration(incremental.label("incremental.local_fastpath.generate")), - 
speedupRatio(normal.total, incremental.total), - ) - } - if err := tw.Flush(); err != nil { - t.Fatalf("flush breakdown table: %v", err) + fmt.Sprintf("%.2fx", speedupRatio(normal.total, incremental.total)), + }) } - fmt.Print(out.String()) + fmt.Print(renderASCIITable(rows)) } func writeIncrementalBenchmarkModule(tb testing.TB, repoRoot string, root string) { @@ -652,3 +666,44 @@ func writeBenchmarkFile(tb testing.TB, path string, content string) { tb.Fatalf("WriteFile failed: %v", err) } } + +func renderASCIITable(rows [][]string) string { + if len(rows) == 0 { + return "" + } + widths := make([]int, len(rows[0])) + for _, row := range rows { + for i, cell := range row { + if len(cell) > widths[i] { + widths[i] = len(cell) + } + } + } + var b strings.Builder + border := func() { + b.WriteByte('+') + for _, width := range widths { + b.WriteString(strings.Repeat("-", width+2)) + b.WriteByte('+') + } + b.WriteByte('\n') + } + writeRow := func(row []string) { + b.WriteByte('|') + for i, cell := range row { + b.WriteByte(' ') + b.WriteString(cell) + b.WriteString(strings.Repeat(" ", widths[i]-len(cell)+1)) + b.WriteByte('|') + } + b.WriteByte('\n') + } + border() + writeRow(rows[0]) + border() + for _, row := range rows[1:] { + writeRow(row) + } + border() + return b.String() +} diff --git a/internal/wire/incremental_manifest.go b/internal/wire/incremental_manifest.go index ae36c77..11d250f 100644 --- a/internal/wire/incremental_manifest.go +++ b/internal/wire/incremental_manifest.go @@ -250,12 +250,8 @@ func buildExternalPackageFiles(wd string, pkgs []*packages.Package) ([]cacheFile } func buildExternalPackageExports(wd string, pkgs []*packages.Package) []externalPackageExport { - moduleRoot := findModuleRoot(wd) out := make([]externalPackageExport, 0) for _, pkg := range collectAllPackages(pkgs) { - if classifyPackageLocation(moduleRoot, pkg) == "local" { - continue - } if pkg == nil || pkg.PkgPath == "" || pkg.ExportFile == "" { continue } diff --git 
a/internal/wire/incremental_summary.go b/internal/wire/incremental_summary.go index faaa9b8..2930b37 100644 --- a/internal/wire/incremental_summary.go +++ b/internal/wire/incremental_summary.go @@ -148,6 +148,10 @@ func writeIncrementalPackageSummary(key string, summary *packageSummary) { } func writeIncrementalPackageSummaries(loader *lazyLoader, pkgs []*packages.Package) { + writeIncrementalPackageSummariesWithSummary(loader, pkgs, nil, nil) +} + +func writeIncrementalPackageSummariesWithSummary(loader *lazyLoader, pkgs []*packages.Package, summary *summaryProviderResolver, only map[string]struct{}) { if loader == nil || len(pkgs) == 0 { return } @@ -162,11 +166,16 @@ func writeIncrementalPackageSummaries(loader *lazyLoader, pkgs []*packages.Packa for _, pkg := range all { allPkgs = append(allPkgs, pkg) } - oc := newObjectCache(allPkgs, loader) + oc := newObjectCacheWithLoader(allPkgs, loader, nil, summary) for _, pkg := range all { if classifyPackageLocation(moduleRoot, pkg) != "local" { continue } + if len(only) > 0 { + if _, ok := only[pkg.PkgPath]; !ok { + continue + } + } if pkg == nil || pkg.TypesInfo == nil || len(pkg.Syntax) == 0 { continue } diff --git a/internal/wire/loader_test.go b/internal/wire/loader_test.go index 2f41c8d..6a26d8e 100644 --- a/internal/wire/loader_test.go +++ b/internal/wire/loader_test.go @@ -794,6 +794,148 @@ func TestGenerateIncrementalShapeChangeMatchesNormalGenerate(t *testing.T) { } } +func TestGenerateIncrementalShapeChangeWithUnchangedDependentPackageMatchesNormalGenerate(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + 
repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"example.com/app/router\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *router.Routes {", + "\twire.Build(dep.Set, router.Set)", + "\treturn nil", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Controller struct { Message string }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func NewController(msg string) *Controller {", + "\treturn &Controller{Message: msg}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(NewMessage, NewController)", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "router", "router.go"), strings.Join([]string{ + "package router", + "", + "import \"example.com/app/dep\"", + "", + "type Routes struct { Controller *dep.Controller }", + "", + "func ProvideRoutes(controller *dep.Controller) *Routes {", + "\treturn &Routes{Controller: controller}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "router", "wire.go"), strings.Join([]string{ + "package router", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(ProvideRoutes)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + incrementalCtx := WithIncremental(context.Background(), true) + + if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + t.Fatalf("baseline incremental Generate returned errors: %v", errs) + } + + writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Controller struct { Message string; Count int }", + "", + "func 
NewMessage() string { return \"ok\" }", + "", + "func NewCount() int { return 7 }", + "", + "func NewController(msg string, count int) *Controller {", + "\treturn &Controller{Message: msg, Count: count}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(NewMessage, NewCount, NewController)", + "", + }, "\n")) + + var incrementalLabels []string + incrementalTimedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { + incrementalLabels = append(incrementalLabels, label) + }) + incrementalGens, errs := Generate(incrementalTimedCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("incremental Generate returned errors: %v", errs) + } + if len(incrementalGens) != 1 || len(incrementalGens[0].Errs) > 0 { + t.Fatalf("unexpected incremental Generate results: %+v", incrementalGens) + } + if !containsLabel(incrementalLabels, "incremental.local_fastpath.load") { + t.Fatalf("expected incremental Generate to use local fast path, labels=%v", incrementalLabels) + } + + normalGens, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("normal Generate returned errors: %v", errs) + } + if len(normalGens) != 1 || len(normalGens[0].Errs) > 0 { + t.Fatalf("unexpected normal Generate results: %+v", normalGens) + } + if string(incrementalGens[0].Content) != string(normalGens[0].Content) { + t.Fatal("incremental output differs from normal Generate output when unchanged package depends on changed package") + } +} + func TestGenerateIncrementalInvalidShapeChangeDoesNotReuseManifest(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() diff --git a/internal/wire/local_fastpath.go b/internal/wire/local_fastpath.go index 466dcc2..4ef1f8f 100644 --- a/internal/wire/local_fastpath.go +++ b/internal/wire/local_fastpath.go @@ 
-59,7 +59,7 @@ func tryIncrementalLocalFastPath(ctx context.Context, wd string, env []string, p } fastPathStart := time.Now() - loaded, err := loadLocalPackagesForFastPath(ctx, wd, opts.Tags, roots[0], state.currentLocal, state.manifest.ExternalPkgs) + loaded, err := loadLocalPackagesForFastPath(ctx, wd, opts.Tags, roots[0], changed, state.currentLocal, state.manifest.ExternalPkgs) if err != nil { debugf(ctx, "incremental.local_fastpath miss reason=%v", err) if shouldBypassIncrementalManifestAfterFastPathError(err) { @@ -69,7 +69,7 @@ func tryIncrementalLocalFastPath(ctx context.Context, wd string, env []string, p } logTiming(ctx, "incremental.local_fastpath.load", fastPathStart) - generated, errs := generateFromTypedPackages(ctx, loaded.root, loaded.allPackages, opts) + generated, errs := generateFromTypedPackages(ctx, loaded, opts) logTiming(ctx, "incremental.local_fastpath.generate", fastPathStart) if len(errs) > 0 { return nil, true, true, errs @@ -91,8 +91,12 @@ func tryIncrementalLocalFastPath(ctx context.Context, wd string, env []string, p for path, pkg := range loaded.byPath { loader.loaded[path] = pkg } + changedSet := make(map[string]struct{}, len(snapshot.changed)) + for _, path := range snapshot.changed { + changedSet[path] = struct{}{} + } writeIncrementalFingerprints(snapshot, wd, opts.Tags) - writeIncrementalPackageSummaries(loader, loaded.allPackages) + writeIncrementalPackageSummariesWithSummary(loader, loaded.allPackages, newSummaryProviderResolver(ctx, loaded.loader.summaries, loaded.loader.importExportPackage), changedSet) writeIncrementalManifestFromState(wd, env, patterns, opts, state, snapshot, generated) writeIncrementalGraphFromSnapshot(wd, opts.Tags, roots, loaded.fingerprints) @@ -105,6 +109,9 @@ func shouldBypassIncrementalManifestAfterFastPathError(err error) bool { return false } msg := err.Error() + if strings.Contains(msg, "missing external export data for ") { + return false + } return strings.Contains(msg, "type-check failed for ") 
} @@ -205,6 +212,7 @@ type localFastPathLoaded struct { allPackages []*packages.Package byPath map[string]*packages.Package fingerprints map[string]*packageFingerprint + loader *localFastPathLoader } type localFastPathLoader struct { @@ -212,14 +220,19 @@ type localFastPathLoader struct { wd string tags string fset *token.FileSet + modulePrefix string rootPkgPath string + changedPkgs map[string]struct{} + sourcePkgs map[string]struct{} + summaries map[string]*packageSummary meta map[string]*packageFingerprint pkgs map[string]*packages.Package + imported map[string]*types.Package externalMeta map[string]externalPackageExport externalImp types.Importer } -func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, rootPkgPath string, current []packageFingerprint, external []externalPackageExport) (*localFastPathLoaded, error) { +func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, rootPkgPath string, changed []string, current []packageFingerprint, external []externalPackageExport) (*localFastPathLoaded, error) { meta := fingerprintsFromSlice(current) if len(meta) == 0 { return nil, fmt.Errorf("no local fingerprints") @@ -239,11 +252,38 @@ func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, r wd: wd, tags: tags, fset: token.NewFileSet(), + modulePrefix: moduleImportPrefix(meta), rootPkgPath: rootPkgPath, + changedPkgs: make(map[string]struct{}, len(changed)), + sourcePkgs: make(map[string]struct{}), + summaries: make(map[string]*packageSummary), meta: meta, pkgs: make(map[string]*packages.Package, len(meta)), + imported: make(map[string]*types.Package, len(meta)+len(externalMeta)), externalMeta: externalMeta, } + for _, path := range changed { + loader.changedPkgs[path] = struct{}{} + } + loader.markSourceClosure() + candidates := make(map[string]*packageSummary) + for path, fp := range meta { + if path == rootPkgPath { + continue + } + if _, changed := loader.changedPkgs[path]; changed { + 
continue + } + if _, ok := externalMeta[path]; !ok { + continue + } + summary, ok := readIncrementalPackageSummary(incrementalSummaryKey(wd, tags, path)) + if !ok || summary == nil || summary.ShapeHash != fp.ShapeHash { + continue + } + candidates[path] = summary + } + loader.summaries = filterSupportedPackageSummaries(candidates) loader.externalImp = importerpkg.ForCompiler(loader.fset, "gc", loader.openExternalExport) root, err := loader.load(rootPkgPath) if err != nil { @@ -260,6 +300,7 @@ func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, r allPackages: all, byPath: loader.pkgs, fingerprints: loader.meta, + loader: loader, }, nil } @@ -275,10 +316,13 @@ func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { if len(files) == 0 { return nil, fmt.Errorf("package %s has no files", pkgPath) } - mode := parser.ParseComments | parser.SkipObjectResolution + mode := parser.SkipObjectResolution + if pkgPath == l.rootPkgPath { + mode |= parser.ParseComments + } syntax := make([]*ast.File, 0, len(files)) for _, name := range files { - file, err := parser.ParseFile(l.fset, name, nil, mode) + file, err := l.parseFileForFastPath(name, mode, pkgPath) if err != nil { return nil, err } @@ -289,15 +333,7 @@ func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { } pkgName := syntax[0].Name.Name - info := &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Defs: make(map[*ast.Ident]types.Object), - Uses: make(map[*ast.Ident]types.Object), - Implicits: make(map[ast.Node]types.Object), - Selections: make(map[*ast.SelectorExpr]*types.Selection), - Scopes: make(map[ast.Node]*types.Scope), - Instances: make(map[*ast.Ident]types.Instance), - } + info := newFastPathTypesInfo(pkgPath == l.rootPkgPath) pkg := &packages.Package{ Fset: l.fset, Name: pkgName, @@ -314,7 +350,8 @@ func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { Importer: importerFunc(func(path string) 
(*types.Package, error) { return l.importPackage(path) }), - Sizes: types.SizesFor("gc", runtime.GOARCH), + IgnoreFuncBodies: l.shouldIgnoreFuncBodies(pkgPath), + Sizes: types.SizesFor("gc", runtime.GOARCH), Error: func(err error) { pkg.Errors = append(pkg.Errors, packages.Error{Msg: err.Error()}) }, @@ -322,6 +359,10 @@ func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { checkedPkg, err := conf.Check(pkgPath, l.fset, syntax, info) if checkedPkg != nil { pkg.Types = checkedPkg + l.imported[pkgPath] = checkedPkg + } + if l.shouldRetryWithoutBodyStripping(pkgPath, pkg.Errors) { + return l.reloadWithoutBodyStripping(pkgPath, files, mode, pkg) } if err != nil && len(pkg.Errors) == 0 { return nil, err @@ -347,7 +388,71 @@ func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { return pkg, nil } +func (l *localFastPathLoader) parseFileForFastPath(name string, mode parser.Mode, pkgPath string) (*ast.File, error) { + file, err := parser.ParseFile(l.fset, name, nil, mode) + if err != nil { + return nil, err + } + if l.shouldStripFunctionBodies(pkgPath) { + stripFunctionBodies(file) + pruneImportsWithoutTopLevelUse(file) + } + return file, nil +} + +func (l *localFastPathLoader) reloadWithoutBodyStripping(pkgPath string, files []string, mode parser.Mode, pkg *packages.Package) (*packages.Package, error) { + syntax := make([]*ast.File, 0, len(files)) + for _, name := range files { + file, err := parser.ParseFile(l.fset, name, nil, mode) + if err != nil { + return nil, err + } + syntax = append(syntax, file) + } + pkg.Syntax = syntax + pkg.Errors = nil + pkg.TypesInfo = newFastPathTypesInfo(pkgPath == l.rootPkgPath) + conf := &types.Config{ + Importer: importerFunc(func(path string) (*types.Package, error) { + return l.importPackage(path) + }), + IgnoreFuncBodies: false, + Sizes: types.SizesFor("gc", runtime.GOARCH), + Error: func(err error) { + pkg.Errors = append(pkg.Errors, packages.Error{Msg: err.Error()}) + }, + } + 
checkedPkg, err := conf.Check(pkgPath, l.fset, syntax, pkg.TypesInfo) + if checkedPkg != nil { + pkg.Types = checkedPkg + l.imported[pkgPath] = checkedPkg + } + if err != nil && len(pkg.Errors) == 0 { + return nil, err + } + if len(pkg.Errors) > 0 { + return nil, formatLocalTypeCheckError(l.wd, pkgPath, pkg.Errors) + } + return pkg, nil +} + +func (l *localFastPathLoader) shouldRetryWithoutBodyStripping(pkgPath string, errs []packages.Error) bool { + if !l.shouldStripFunctionBodies(pkgPath) || len(errs) == 0 { + return false + } + for _, pkgErr := range errs { + msg := pkgErr.Msg + if strings.Contains(msg, "missing function body") || strings.Contains(msg, "func init must have a body") { + return true + } + } + return false +} + func (l *localFastPathLoader) importPackage(path string) (*types.Package, error) { + if l.shouldImportFromExport(path) { + return l.importExportPackage(path) + } if l.meta[path] != nil { pkg, err := l.load(path) if err != nil { @@ -358,17 +463,132 @@ func (l *localFastPathLoader) importPackage(path string) (*types.Package, error) if l.externalImp == nil { return nil, fmt.Errorf("missing external importer") } - return l.externalImp.Import(path) + return l.importExportPackage(path) } func (l *localFastPathLoader) openExternalExport(path string) (io.ReadCloser, error) { meta, ok := l.externalMeta[path] if !ok || meta.ExportFile == "" { - return nil, fmt.Errorf("missing export data for %s", path) + if l.meta[path] != nil || l.isLikelyLocalImport(path) { + return nil, fmt.Errorf("missing local export data for %s", path) + } + return nil, fmt.Errorf("missing external export data for %s", path) } return os.Open(meta.ExportFile) } +func (l *localFastPathLoader) isLikelyLocalImport(path string) bool { + if l == nil || l.modulePrefix == "" { + return false + } + return path == l.modulePrefix || strings.HasPrefix(path, l.modulePrefix+"/") +} + +func moduleImportPrefix(meta map[string]*packageFingerprint) string { + if len(meta) == 0 { + return "" + } + 
paths := make([]string, 0, len(meta)) + for path := range meta { + paths = append(paths, path) + } + sort.Strings(paths) + prefix := strings.Split(paths[0], "/") + for _, path := range paths[1:] { + parts := strings.Split(path, "/") + n := len(prefix) + if len(parts) < n { + n = len(parts) + } + i := 0 + for i < n && prefix[i] == parts[i] { + i++ + } + prefix = prefix[:i] + if len(prefix) == 0 { + return "" + } + } + return strings.Join(prefix, "/") +} + +func (l *localFastPathLoader) importExportPackage(path string) (*types.Package, error) { + if l == nil { + return nil, fmt.Errorf("missing local fast path loader") + } + if pkg := l.imported[path]; pkg != nil { + return pkg, nil + } + if l.externalImp == nil { + return nil, fmt.Errorf("missing external importer") + } + pkg, err := l.externalImp.Import(path) + if err != nil { + return nil, err + } + l.imported[path] = pkg + return pkg, nil +} + +func (l *localFastPathLoader) shouldImportFromExport(pkgPath string) bool { + if l == nil { + return false + } + if _, source := l.sourcePkgs[pkgPath]; source { + return false + } + _, ok := l.summaries[pkgPath] + return ok +} + +func (l *localFastPathLoader) markSourceClosure() { + if l == nil { + return + } + reverse := make(map[string][]string) + for pkgPath, fp := range l.meta { + if fp == nil { + continue + } + for _, imp := range fp.LocalImports { + reverse[imp] = append(reverse[imp], pkgPath) + } + } + queue := make([]string, 0, len(l.changedPkgs)+1) + queue = append(queue, l.rootPkgPath) + for pkgPath := range l.changedPkgs { + queue = append(queue, pkgPath) + } + for len(queue) > 0 { + pkgPath := queue[0] + queue = queue[1:] + if _, seen := l.sourcePkgs[pkgPath]; seen { + continue + } + l.sourcePkgs[pkgPath] = struct{}{} + for _, importer := range reverse[pkgPath] { + if _, seen := l.sourcePkgs[importer]; !seen { + queue = append(queue, importer) + } + } + } +} + +func (l *localFastPathLoader) shouldStripFunctionBodies(pkgPath string) bool { + if l == nil { + 
return false + } + if pkgPath == l.rootPkgPath { + return false + } + _, changed := l.changedPkgs[pkgPath] + return !changed +} + +func (l *localFastPathLoader) shouldIgnoreFuncBodies(pkgPath string) bool { + return l.shouldStripFunctionBodies(pkgPath) +} + type importerFunc func(string) (*types.Package, error) func (fn importerFunc) Import(path string) (*types.Package, error) { @@ -395,7 +615,105 @@ func packageImportPaths(files []*ast.File) []string { return out } -func generateFromTypedPackages(ctx context.Context, root *packages.Package, allPkgs []*packages.Package, opts *GenerateOptions) ([]GenerateResult, []error) { +func newFastPathTypesInfo(full bool) *types.Info { + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + } + if !full { + return info + } + info.Implicits = make(map[ast.Node]types.Object) + info.Selections = make(map[*ast.SelectorExpr]*types.Selection) + info.Scopes = make(map[ast.Node]*types.Scope) + info.Instances = make(map[*ast.Ident]types.Instance) + return info +} + +func pruneImportsWithoutTopLevelUse(file *ast.File) { + if file == nil || len(file.Imports) == 0 { + return + } + used := usedImportNames(file) + filtered := file.Imports[:0] + for _, spec := range file.Imports { + if spec == nil || spec.Path == nil { + continue + } + name := importName(spec) + if name == "_" || name == "." { + filtered = append(filtered, spec) + continue + } + if _, ok := used[name]; ok { + filtered = append(filtered, spec) + } + } + file.Imports = filtered + for _, decl := range file.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok || gen.Tok != token.IMPORT { + continue + } + specs := gen.Specs[:0] + for _, spec := range gen.Specs { + importSpec, ok := spec.(*ast.ImportSpec) + if !ok || importSpec.Path == nil { + continue + } + name := importName(importSpec) + if name == "_" || name == "." 
{ + specs = append(specs, spec) + continue + } + if _, ok := used[name]; ok { + specs = append(specs, spec) + } + } + gen.Specs = specs + } +} + +func usedImportNames(file *ast.File) map[string]struct{} { + used := make(map[string]struct{}) + ast.Inspect(file, func(node ast.Node) bool { + sel, ok := node.(*ast.SelectorExpr) + if !ok { + return true + } + ident, ok := sel.X.(*ast.Ident) + if !ok || ident.Name == "" { + return true + } + used[ident.Name] = struct{}{} + return true + }) + return used +} + +func importName(spec *ast.ImportSpec) string { + if spec == nil || spec.Path == nil { + return "" + } + if spec.Name != nil && spec.Name.Name != "" { + return spec.Name.Name + } + path := strings.Trim(spec.Path.Value, "\"") + if path == "" { + return "" + } + if slash := strings.LastIndex(path, "/"); slash >= 0 { + path = path[slash+1:] + } + return path +} + +func generateFromTypedPackages(ctx context.Context, loaded *localFastPathLoaded, opts *GenerateOptions) ([]GenerateResult, []error) { + if loaded == nil { + return nil, []error{fmt.Errorf("missing loaded packages")} + } + root := loaded.root if root == nil { return nil, []error{fmt.Errorf("missing root package")} } @@ -412,7 +730,11 @@ func generateFromTypedPackages(ctx context.Context, root *packages.Package, allP } res.OutputPath = filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") - oc := newObjectCache(allPkgs, nil) + var summary *summaryProviderResolver + if loaded.loader != nil { + summary = newSummaryProviderResolver(ctx, loaded.loader.summaries, loaded.loader.importExportPackage) + } + oc := newObjectCacheWithLoader(loaded.allPackages, nil, nil, summary) g := newGen(root) injectorStart := time.Now() injectorFiles, errs := generateInjectors(oc, g, root) @@ -447,9 +769,12 @@ func writeIncrementalFingerprints(snapshot *incrementalFingerprintSnapshot, wd s if snapshot == nil { return } - for _, fp := range snapshotPackageFingerprints(snapshot) { - fp := fp - 
writeIncrementalFingerprint(incrementalFingerprintKey(wd, tags, fp.PkgPath), &fp) + for _, path := range snapshot.changed { + fp := snapshot.fingerprints[path] + if fp == nil { + continue + } + writeIncrementalFingerprint(incrementalFingerprintKey(wd, tags, fp.PkgPath), fp) } } diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 2f038a9..73f218d 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -471,6 +471,7 @@ type objectCache struct { objects map[objRef]objCacheEntry hasher typeutil.Hasher loader *lazyLoader + summary *summaryProviderResolver } type objRef struct { @@ -484,6 +485,10 @@ type objCacheEntry struct { } func newObjectCache(pkgs []*packages.Package, loader *lazyLoader) *objectCache { + return newObjectCacheWithLoader(pkgs, loader, nil, nil) +} + +func newObjectCacheWithLoader(pkgs []*packages.Package, loader *lazyLoader, _ *localFastPathLoader, summary *summaryProviderResolver) *objectCache { if len(pkgs) == 0 { panic("object cache must have packages to draw from") } @@ -493,6 +498,7 @@ func newObjectCache(pkgs []*packages.Package, loader *lazyLoader) *objectCache { objects: make(map[objRef]objCacheEntry), hasher: typeutil.MakeHasher(), loader: loader, + summary: summary, } if oc.fset == nil && loader != nil { oc.fset = loader.fset @@ -557,9 +563,6 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { if ent, cached := oc.objects[ref]; cached { return ent.val, append([]error(nil), ent.errs...) 
} - if _, errs := oc.ensurePackage(ref.importPath); len(errs) > 0 { - return nil, errs - } defer func() { oc.objects[ref] = objCacheEntry{ val: val, @@ -568,6 +571,14 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { }() switch obj := obj.(type) { case *types.Var: + if isProviderSetType(obj.Type()) && oc.summary != nil { + if pset, ok, summaryErrs := oc.summary.Resolve(obj.Pkg().Path(), obj.Name()); ok { + return pset, summaryErrs + } + } + if _, errs := oc.ensurePackage(ref.importPath); len(errs) > 0 { + return nil, errs + } spec := oc.varDecl(obj) if spec == nil || len(spec.Values) == 0 { return nil, []error{fmt.Errorf("%v is not a provider or a provider set", obj)} @@ -583,6 +594,9 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { case *types.Func: return processFuncProvider(oc.fset, obj) default: + if _, errs := oc.ensurePackage(ref.importPath); len(errs) > 0 { + return nil, errs + } return nil, []error{fmt.Errorf("%v is not a provider or a provider set", obj)} } } diff --git a/internal/wire/summary_provider_resolver.go b/internal/wire/summary_provider_resolver.go new file mode 100644 index 0000000..c93e0c5 --- /dev/null +++ b/internal/wire/summary_provider_resolver.go @@ -0,0 +1,223 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package wire + +import ( + "context" + "fmt" + "go/token" + "go/types" + "time" + + "golang.org/x/tools/go/types/typeutil" +) + +type summaryProviderResolver struct { + ctx context.Context + fset *token.FileSet + summaries map[string]*packageSummary + importPackage func(string) (*types.Package, error) + cache map[providerSetRefSummary]*ProviderSet + resolving map[providerSetRefSummary]struct{} + supported map[string]bool +} + +func newSummaryProviderResolver(ctx context.Context, summaries map[string]*packageSummary, importPackage func(string) (*types.Package, error)) *summaryProviderResolver { + if len(summaries) == 0 || importPackage == nil { + return nil + } + r := &summaryProviderResolver{ + ctx: ctx, + fset: token.NewFileSet(), + summaries: make(map[string]*packageSummary, len(summaries)), + importPackage: importPackage, + cache: make(map[providerSetRefSummary]*ProviderSet), + resolving: make(map[providerSetRefSummary]struct{}), + supported: make(map[string]bool, len(summaries)), + } + for pkgPath, summary := range summaries { + if summary == nil { + continue + } + r.summaries[pkgPath] = summary + } + for pkgPath := range r.summaries { + r.supported[pkgPath] = r.packageSupported(pkgPath, make(map[string]struct{})) + } + return r +} + +func filterSupportedPackageSummaries(summaries map[string]*packageSummary) map[string]*packageSummary { + if len(summaries) == 0 { + return nil + } + resolver := &summaryProviderResolver{ + summaries: summaries, + supported: make(map[string]bool, len(summaries)), + } + out := make(map[string]*packageSummary) + for pkgPath, summary := range summaries { + if summary == nil { + continue + } + if resolver.packageSupported(pkgPath, make(map[string]struct{})) { + out[pkgPath] = summary + } + } + return out +} + +func (r *summaryProviderResolver) Resolve(pkgPath string, varName string) (*ProviderSet, bool, []error) { + if r == nil || !r.supported[pkgPath] { + return nil, false, nil + } + start := time.Now() + set, err := 
r.resolve(providerSetRefSummary{PkgPath: pkgPath, VarName: varName}) + logTiming(r.ctx, "incremental.local_fastpath.summary_resolve", start) + if err != nil { + return nil, true, []error{err} + } + return set, true, nil +} + +func (r *summaryProviderResolver) resolve(ref providerSetRefSummary) (*ProviderSet, error) { + if set := r.cache[ref]; set != nil { + return set, nil + } + if _, ok := r.resolving[ref]; ok { + return nil, fmt.Errorf("summary provider set cycle for %s.%s", ref.PkgPath, ref.VarName) + } + summary := r.summaries[ref.PkgPath] + if summary == nil { + return nil, fmt.Errorf("missing package summary for %s", ref.PkgPath) + } + setSummary, ok := r.findProviderSet(summary, ref.VarName) + if !ok { + return nil, fmt.Errorf("missing provider set summary for %s.%s", ref.PkgPath, ref.VarName) + } + r.resolving[ref] = struct{}{} + defer delete(r.resolving, ref) + + pkg, err := r.importPackage(ref.PkgPath) + if err != nil { + return nil, err + } + set := &ProviderSet{ + PkgPath: ref.PkgPath, + VarName: ref.VarName, + } + for _, provider := range setSummary.Providers { + resolved, err := r.resolveProvider(pkg, provider) + if err != nil { + return nil, err + } + set.Providers = append(set.Providers, resolved) + } + for _, imported := range setSummary.Imports { + child, err := r.resolve(imported) + if err != nil { + return nil, err + } + set.Imports = append(set.Imports, child) + } + hasher := typeutil.MakeHasher() + providerMap, srcMap, errs := buildProviderMap(r.fset, hasher, set) + if len(errs) > 0 { + return nil, errs[0] + } + if errs := verifyAcyclic(providerMap, hasher); len(errs) > 0 { + return nil, errs[0] + } + set.providerMap = providerMap + set.srcMap = srcMap + r.cache[ref] = set + return set, nil +} + +func (r *summaryProviderResolver) resolveProvider(pkg *types.Package, summary providerSummary) (*Provider, error) { + if summary.IsStruct || len(summary.Out) == 0 { + return nil, fmt.Errorf("unsupported summary provider %s.%s", summary.PkgPath, 
summary.Name) + } + if pkg == nil || pkg.Path() != summary.PkgPath { + var err error + pkg, err = r.importPackage(summary.PkgPath) + if err != nil { + return nil, err + } + } + obj := pkg.Scope().Lookup(summary.Name) + fn, ok := obj.(*types.Func) + if !ok { + return nil, fmt.Errorf("summary provider %s.%s missing function", summary.PkgPath, summary.Name) + } + provider, errs := processFuncProvider(r.fset, fn) + if len(errs) > 0 { + return nil, errs[0] + } + return provider, nil +} + +func (r *summaryProviderResolver) findProviderSet(summary *packageSummary, varName string) (providerSetSummary, bool) { + if summary == nil { + return providerSetSummary{}, false + } + for _, set := range summary.ProviderSets { + if set.VarName == varName { + return set, true + } + } + return providerSetSummary{}, false +} + +func (r *summaryProviderResolver) packageSupported(pkgPath string, visiting map[string]struct{}) bool { + if ok, seen := r.supported[pkgPath]; seen { + return ok + } + if _, seen := visiting[pkgPath]; seen { + return false + } + summary := r.summaries[pkgPath] + if summary == nil { + return false + } + visiting[pkgPath] = struct{}{} + defer delete(visiting, pkgPath) + for _, set := range summary.ProviderSets { + if !providerSetSummarySupported(set) { + return false + } + for _, imported := range set.Imports { + if _, ok := r.summaries[imported.PkgPath]; !ok { + return false + } + if !r.packageSupported(imported.PkgPath, visiting) { + return false + } + } + } + return true +} + +func providerSetSummarySupported(summary providerSetSummary) bool { + if len(summary.Bindings) > 0 || len(summary.Values) > 0 || len(summary.Fields) > 0 || len(summary.InputTypes) > 0 { + return false + } + for _, provider := range summary.Providers { + if provider.IsStruct { + return false + } + } + return true +} From 2eb540098f701a041e7eff360bf2f72c321873f7 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Fri, 13 Mar 2026 23:47:20 -0500 Subject: [PATCH 03/82] perf(incremental): trim cold 
bootstrap work and keep warm shape changes fast --- internal/wire/incremental.go | 20 +++++ internal/wire/incremental_fingerprint.go | 99 ++++++++++++++++++++++- internal/wire/incremental_manifest.go | 14 +++- internal/wire/incremental_summary_test.go | 10 ++- internal/wire/loader_test.go | 37 +++++++++ internal/wire/parse.go | 7 +- internal/wire/wire.go | 37 ++++++++- 7 files changed, 213 insertions(+), 11 deletions(-) diff --git a/internal/wire/incremental.go b/internal/wire/incremental.go index 0bc334c..007027b 100644 --- a/internal/wire/incremental.go +++ b/internal/wire/incremental.go @@ -23,6 +23,7 @@ import ( const IncrementalEnvVar = "WIRE_INCREMENTAL" type incrementalKey struct{} +type incrementalColdBootstrapKey struct{} // WithIncremental overrides incremental-mode resolution for the provided // context. This takes precedence over the environment variable. @@ -33,6 +34,13 @@ func WithIncremental(ctx context.Context, enabled bool) context.Context { return context.WithValue(ctx, incrementalKey{}, enabled) } +func withIncrementalColdBootstrap(ctx context.Context, enabled bool) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, incrementalColdBootstrapKey{}, enabled) +} + // IncrementalEnabled reports whether incremental mode is enabled for the // current operation. A context override takes precedence over env. 
func IncrementalEnabled(ctx context.Context, env []string) bool { @@ -54,6 +62,18 @@ func IncrementalEnabled(ctx context.Context, env []string) bool { return enabled } +func incrementalColdBootstrapEnabled(ctx context.Context) bool { + if ctx == nil { + return false + } + if v := ctx.Value(incrementalColdBootstrapKey{}); v != nil { + if enabled, ok := v.(bool); ok { + return enabled + } + } + return false +} + func lookupEnv(env []string, key string) (string, bool) { prefix := key + "=" for i := len(env) - 1; i >= 0; i-- { diff --git a/internal/wire/incremental_fingerprint.go b/internal/wire/incremental_fingerprint.go index 886d07f..46485f7 100644 --- a/internal/wire/incremental_fingerprint.go +++ b/internal/wire/incremental_fingerprint.go @@ -124,6 +124,52 @@ func collectIncrementalFingerprints(wd string, tags string, pkgs []*packages.Pac return snapshot } +func buildIncrementalManifestSnapshotFromPackages(wd string, tags string, pkgs []*packages.Package) *incrementalFingerprintSnapshot { + all := collectAllPackages(pkgs) + moduleRoot := findModuleRoot(wd) + snapshot := &incrementalFingerprintSnapshot{ + fingerprints: make(map[string]*packageFingerprint), + } + for _, pkg := range all { + if classifyPackageLocation(moduleRoot, pkg) != "local" { + continue + } + files := packageFingerprintFiles(pkg) + if len(files) == 0 { + continue + } + sort.Strings(files) + metaFiles, err := buildCacheFiles(files) + if err != nil { + continue + } + shapeHash, err := packageShapeHashFromSyntax(pkg, files) + if err != nil { + continue + } + localImports := make([]string, 0, len(pkg.Imports)) + for _, imp := range pkg.Imports { + if classifyPackageLocation(moduleRoot, imp) == "local" { + localImports = append(localImports, imp.PkgPath) + } + } + sort.Strings(localImports) + snapshot.fingerprints[pkg.PkgPath] = &packageFingerprint{ + Version: incrementalFingerprintVersion, + WD: filepath.Clean(wd), + Tags: tags, + PkgPath: pkg.PkgPath, + Files: metaFiles, + ShapeHash: shapeHash, + 
LocalImports: localImports, + } + } + if len(snapshot.fingerprints) == 0 { + return nil + } + return snapshot +} + func packageFingerprintFiles(pkg *packages.Package) []string { if pkg == nil { return nil @@ -202,16 +248,63 @@ func packageShapeHash(files []string) (string, error) { if err != nil { return "", err } - stripFunctionBodies(file) - if err := printer.Fprint(&buf, fset, file); err != nil { - return "", err + writeSyntaxShapeHash(&buf, fset, file) + buf.WriteByte(0) + } + sum := sha256.Sum256(buf.Bytes()) + return fmt.Sprintf("%x", sum[:]), nil +} + +func packageShapeHashFromSyntax(pkg *packages.Package, files []string) (string, error) { + if pkg == nil || len(pkg.Syntax) == 0 || pkg.Fset == nil { + return packageShapeHash(files) + } + var buf bytes.Buffer + for _, file := range pkg.Syntax { + if file == nil { + continue } + writeSyntaxShapeHash(&buf, pkg.Fset, file) buf.WriteByte(0) } sum := sha256.Sum256(buf.Bytes()) return fmt.Sprintf("%x", sum[:]), nil } +func writeSyntaxShapeHash(buf *bytes.Buffer, fset *token.FileSet, file *ast.File) { + if file == nil || buf == nil || fset == nil { + return + } + if file.Name != nil { + buf.WriteString("package ") + buf.WriteString(file.Name.Name) + buf.WriteByte('\n') + } + for _, decl := range file.Decls { + switch decl := decl.(type) { + case *ast.FuncDecl: + writeNodeHash(buf, fset, decl.Recv) + buf.WriteByte(' ') + if decl.Name != nil { + buf.WriteString(decl.Name.Name) + } + buf.WriteByte(' ') + writeNodeHash(buf, fset, decl.Type) + buf.WriteByte('\n') + default: + writeNodeHash(buf, fset, decl) + buf.WriteByte('\n') + } + } +} + +func writeNodeHash(buf *bytes.Buffer, fset *token.FileSet, node interface{}) { + if buf == nil || fset == nil || node == nil { + return + } + _ = printer.Fprint(buf, fset, node) +} + func stripFunctionBodies(file *ast.File) { if file == nil { return diff --git a/internal/wire/incremental_manifest.go b/internal/wire/incremental_manifest.go index 11d250f..cd88976 100644 --- 
a/internal/wire/incremental_manifest.go +++ b/internal/wire/incremental_manifest.go @@ -143,13 +143,21 @@ func readIncrementalManifestResults(ctx context.Context, wd string, env []string } func writeIncrementalManifest(wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot, generated []GenerateResult) { + writeIncrementalManifestWithOptions(wd, env, patterns, opts, pkgs, snapshot, generated, true) +} + +func writeIncrementalManifestWithOptions(wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot, generated []GenerateResult, includeExternalFiles bool) { if snapshot == nil || len(generated) == 0 { return } externalPkgs := buildExternalPackageExports(wd, pkgs) - externalFiles, err := buildExternalPackageFiles(wd, pkgs) - if err != nil { - return + var externalFiles []cacheFile + if includeExternalFiles { + var err error + externalFiles, err = buildExternalPackageFiles(wd, pkgs) + if err != nil { + return + } } manifest := &incrementalManifest{ Version: incrementalManifestVersion, diff --git a/internal/wire/incremental_summary_test.go b/internal/wire/incremental_summary_test.go index efb4028..ae85651 100644 --- a/internal/wire/incremental_summary_test.go +++ b/internal/wire/incremental_summary_test.go @@ -242,6 +242,14 @@ func TestCollectIncrementalPackageSummariesUsesCacheForUnchanged(t *testing.T) { if len(gens) != 1 || len(gens[0].Errs) > 0 { t.Fatalf("unexpected Generate result: %+v", gens) } + pkgs, loader, errs := load(ctx, root, env, "", []string{"./app"}) + if len(errs) > 0 { + t.Fatalf("load returned errors while seeding summaries: %v", errs) + } + if _, errs := newObjectCache(pkgs, loader).ensurePackage("example.com/app/app"); len(errs) > 0 { + t.Fatalf("ensurePackage returned errors while seeding summaries: %v", errs) + } + writeIncrementalPackageSummaries(loader, pkgs) writeFile(t, depFile, 
strings.Join([]string{ "package dep", @@ -264,7 +272,7 @@ func TestCollectIncrementalPackageSummariesUsesCacheForUnchanged(t *testing.T) { "", }, "\n")) - pkgs, loader, errs := load(ctx, root, env, "", []string{"./app"}) + pkgs, loader, errs = load(ctx, root, env, "", []string{"./app"}) if len(errs) > 0 { t.Fatalf("load returned errors: %v", errs) } diff --git a/internal/wire/loader_test.go b/internal/wire/loader_test.go index 6a26d8e..6899249 100644 --- a/internal/wire/loader_test.go +++ b/internal/wire/loader_test.go @@ -794,6 +794,43 @@ func TestGenerateIncrementalShapeChangeMatchesNormalGenerate(t *testing.T) { } } +func TestGenerateIncrementalColdBootstrapStillSeedsFastPath(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + writeLargeBenchmarkModule(t, repoRoot, root, 24) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + t.Fatalf("cold bootstrap Generate returned errors: %v", errs) + } + + mutateLargeBenchmarkModule(t, root, 12) + + var labels []string + timedCtx := WithTiming(ctx, func(label string, _ time.Duration) { + labels = append(labels, label) + }) + gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("shape-change Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + t.Fatalf("unexpected Generate results: %+v", gens) + } + if !containsLabel(labels, "incremental.local_fastpath.load") { + t.Fatalf("expected cold bootstrap to seed fast path, labels=%v", labels) + } +} + func TestGenerateIncrementalShapeChangeWithUnchangedDependentPackageMatchesNormalGenerate(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() 
diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 73f218d..0cb0551 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -403,8 +403,11 @@ func load(ctx context.Context, wd string, env []string, tags string, patterns [] if len(errs) > 0 { return nil, nil, errs } - fingerprints := analyzeIncrementalFingerprints(ctx, wd, env, tags, pkgs) - analyzeIncrementalGraph(ctx, wd, env, tags, pkgs, fingerprints) + var fingerprints *incrementalFingerprintSnapshot + if !incrementalColdBootstrapEnabled(ctx) { + fingerprints = analyzeIncrementalFingerprints(ctx, wd, env, tags, pkgs) + analyzeIncrementalGraph(ctx, wd, env, tags, pkgs, fingerprints) + } baseFiles := collectPackageFiles(pkgs) loader := &lazyLoader{ diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 64202dc..8b617ea 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -103,9 +103,14 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o } var preloadState *incrementalPreloadState bypassIncrementalManifest := false + coldBootstrap := false if IncrementalEnabled(ctx, env) { debugf(ctx, "incremental=enabled") preloadState, _ = prepareIncrementalPreloadState(ctx, wd, env, patterns, opts) + coldBootstrap = preloadState == nil + if coldBootstrap { + ctx = withIncrementalColdBootstrap(ctx, true) + } if cached, ok := readPreloadIncrementalManifestResultsFromState(ctx, wd, env, patterns, opts, preloadState, preloadState != nil); ok { return cached, nil } @@ -139,14 +144,42 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o } if allGeneratedOK(generated) { if IncrementalEnabled(ctx, env) { - writeIncrementalPackageSummaries(loader, pkgs) + if coldBootstrap { + snapshot := buildIncrementalManifestSnapshotFromPackages(wd, opts.Tags, incrementalManifestPackages(pkgs, loader)) + writeIncrementalManifestWithOptions(wd, env, patterns, opts, incrementalManifestPackages(pkgs, loader), snapshot, generated, false) + if 
snapshot != nil { + writeIncrementalGraphFromSnapshot(wd, opts.Tags, manifestOutputPkgPathsFromGenerated(generated), snapshot.fingerprints) + } + } else { + writeIncrementalPackageSummaries(loader, pkgs) + writeIncrementalManifest(wd, env, patterns, opts, incrementalManifestPackages(pkgs, loader), loader.fingerprints, generated) + } } writeManifest(wd, env, patterns, opts, pkgs) - writeIncrementalManifest(wd, env, patterns, opts, incrementalManifestPackages(pkgs, loader), loader.fingerprints, generated) } return generated, nil } +func manifestOutputPkgPathsFromGenerated(generated []GenerateResult) []string { + if len(generated) == 0 { + return nil + } + seen := make(map[string]struct{}, len(generated)) + out := make([]string, 0, len(generated)) + for _, gen := range generated { + if gen.PkgPath == "" { + continue + } + if _, ok := seen[gen.PkgPath]; ok { + continue + } + seen[gen.PkgPath] = struct{}{} + out = append(out, gen.PkgPath) + } + sort.Strings(out) + return out +} + func warmPackageOutputCache(pkgs []*packages.Package, opts *GenerateOptions, generated []GenerateResult) { if len(pkgs) == 0 || len(generated) == 0 { return From ad2d561df609fbda6a5780547d21b3abf55e800c Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sat, 14 Mar 2026 00:11:01 -0500 Subject: [PATCH 04/82] perf(incremental): load deps conditionally --- internal/wire/parse.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 0cb0551..a7a1a02 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -377,7 +377,7 @@ func load(ctx context.Context, wd string, env []string, tags string, patterns [] } baseCfg := &packages.Config{ Context: ctx, - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps, + Mode: baseLoadMode(ctx), Dir: wd, Env: env, BuildFlags: []string{"-tags=wireinject"}, @@ -423,6 +423,14 @@ func load(ctx context.Context, wd string, 
env []string, tags string, patterns [] return pkgs, loader, nil } +func baseLoadMode(ctx context.Context) packages.LoadMode { + mode := packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports + if !incrementalColdBootstrapEnabled(ctx) { + mode |= packages.NeedDeps + } + return mode +} + func collectLoadErrors(pkgs []*packages.Package) []error { var errs []error for _, p := range pkgs { From 578b24f7be26cab46528ad6ce71f58266921189a Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sat, 14 Mar 2026 00:25:39 -0500 Subject: [PATCH 05/82] chore(incremental): clear session cache --- cmd/wire/cache_cmd.go | 7 +++++-- internal/wire/cache_coverage_test.go | 30 ++++++++++++++++++++++++++++ internal/wire/cache_store.go | 1 + internal/wire/incremental_session.go | 7 +++++++ 4 files changed, 43 insertions(+), 2 deletions(-) diff --git a/cmd/wire/cache_cmd.go b/cmd/wire/cache_cmd.go index e5ceda4..f34d381 100644 --- a/cmd/wire/cache_cmd.go +++ b/cmd/wire/cache_cmd.go @@ -38,9 +38,9 @@ func (*cacheCmd) Synopsis() string { // Usage returns the help text for the subcommand. func (*cacheCmd) Usage() string { - return `cache [-clear] + return `cache [-clear|clear] - By default, prints the cache directory. With -clear, removes all cache files. + By default, prints the cache directory. With -clear or clear, removes all cache files. ` } @@ -51,6 +51,9 @@ func (cmd *cacheCmd) SetFlags(f *flag.FlagSet) { // Execute runs the subcommand. 
func (cmd *cacheCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + if f.NArg() > 0 && f.Arg(0) == "clear" { + cmd.clear = true + } if cmd.clear { if err := wire.ClearCache(); err != nil { log.Printf("failed to clear cache: %v\n", err) diff --git a/internal/wire/cache_coverage_test.go b/internal/wire/cache_coverage_test.go index d65de26..316f30e 100644 --- a/internal/wire/cache_coverage_test.go +++ b/internal/wire/cache_coverage_test.go @@ -166,6 +166,36 @@ func TestCacheStoreReadWrite(t *testing.T) { } } +func TestClearCacheClearsIncrementalSessions(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + tempDir := t.TempDir() + osTempDir = func() string { return tempDir } + + sessionA := getIncrementalSession("/tmp/example", []string{"WIRE_INCREMENTAL=1"}, "") + if sessionA == nil { + t.Fatal("expected incremental session") + } + sessionB := getIncrementalSession("/tmp/example", []string{"WIRE_INCREMENTAL=1"}, "") + if sessionA != sessionB { + t.Fatal("expected same incremental session before clear") + } + + if err := ClearCache(); err != nil { + t.Fatalf("ClearCache failed: %v", err) + } + + sessionC := getIncrementalSession("/tmp/example", []string{"WIRE_INCREMENTAL=1"}, "") + if sessionC == nil { + t.Fatal("expected incremental session after clear") + } + if sessionC == sessionA { + t.Fatal("expected ClearCache to drop in-process incremental sessions") + } +} + func TestCacheStoreReadError(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() diff --git a/internal/wire/cache_store.go b/internal/wire/cache_store.go index dce5565..0c959cf 100644 --- a/internal/wire/cache_store.go +++ b/internal/wire/cache_store.go @@ -32,6 +32,7 @@ func CacheDir() string { // ClearCache removes all cached data. 
func ClearCache() error { + clearIncrementalSessions() return osRemoveAll(cacheDir()) } diff --git a/internal/wire/incremental_session.go b/internal/wire/incremental_session.go index fda6605..72b051b 100644 --- a/internal/wire/incremental_session.go +++ b/internal/wire/incremental_session.go @@ -37,6 +37,13 @@ type cachedParsedFile struct { var incrementalSessions sync.Map +func clearIncrementalSessions() { + incrementalSessions.Range(func(key, _ any) bool { + incrementalSessions.Delete(key) + return true + }) +} + func sessionKey(wd string, env []string, tags string) string { var b strings.Builder b.WriteString(filepath.Clean(wd)) From 83806b9e0dbb0d9da99bd0749c50801dbf54449a Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sat, 14 Mar 2026 13:33:42 -0500 Subject: [PATCH 06/82] fix(cli): improve wire error coloring and solve error labeling --- cmd/wire/main.go | 72 ++++++- cmd/wire/main_test.go | 114 ++++++++++ internal/wire/incremental_fingerprint.go | 3 + internal/wire/incremental_manifest.go | 9 + internal/wire/loader_test.go | 264 ++++++++++++++++++++++- internal/wire/local_fastpath.go | 20 ++ internal/wire/parser_lazy_loader.go | 23 +- internal/wire/wire.go | 6 + 8 files changed, 495 insertions(+), 16 deletions(-) create mode 100644 cmd/wire/main_test.go diff --git a/cmd/wire/main.go b/cmd/wire/main.go index 3166531..d40e439 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -21,6 +21,7 @@ import ( "context" "flag" "fmt" + "io" "io/ioutil" "log" "os" @@ -37,7 +38,7 @@ import ( var topLevelIncremental optionalBoolFlag const ( - ansiRed = "\033[31m" + ansiRed = "\033[1;31m" ansiReset = "\033[0m" ) @@ -208,7 +209,7 @@ func newGenerateOptions(headerFile string) (*wire.GenerateOptions, error) { // logErrors logs each error with consistent formatting. 
func logErrors(errs []error) { for _, err := range errs { - msg := err.Error() + msg := formatLoggedError(err) if strings.Contains(msg, "\n") { logMultilineError("\n " + strings.ReplaceAll(msg, "\n", "\n ")) continue @@ -217,25 +218,78 @@ func logErrors(errs []error) { } } -func logMultilineError(msg string) { - if shouldColorStderr() { - log.Print(ansiRed + msg + ansiReset) - return +func formatLoggedError(err error) string { + if err == nil { + return "" + } + msg := err.Error() + if strings.HasPrefix(msg, "inject ") { + return "solve failed\n" + msg } - log.Print(msg) + if idx := strings.Index(msg, ": inject "); idx >= 0 { + return "solve failed\n" + msg + } + return msg +} + +func logMultilineError(msg string) { + writeErrorLog(os.Stderr, msg) } func shouldColorStderr() bool { - if os.Getenv("NO_COLOR") != "" { + return shouldColorOutput(stderrIsTTY(), os.Getenv("TERM")) +} + +func shouldColorOutput(isTTY bool, term string) bool { + if os.Getenv("NO_COLOR") != "" || os.Getenv("CLICOLOR") == "0" { return false } - term := os.Getenv("TERM") + if forceColorEnabled() { + return true + } if term == "" || term == "dumb" { return false } + return isTTY +} + +func forceColorEnabled() bool { + return os.Getenv("FORCE_COLOR") != "" || os.Getenv("CLICOLOR_FORCE") != "" +} + +func stderrIsTTY() bool { info, err := os.Stderr.Stat() if err != nil { return false } return (info.Mode() & os.ModeCharDevice) != 0 } + +func writeErrorLog(w io.Writer, msg string) { + line := "wire: " + msg + if !strings.HasSuffix(line, "\n") { + line += "\n" + } + if shouldColorStderr() { + _, _ = io.WriteString(w, colorizeLines(line)) + return + } + _, _ = io.WriteString(w, line) +} + +func colorizeLines(s string) string { + if s == "" { + return "" + } + parts := strings.SplitAfter(s, "\n") + var b strings.Builder + for _, part := range parts { + if part == "" { + continue + } + b.WriteString(ansiRed) + b.WriteString(part) + b.WriteString(ansiReset) + } + return b.String() +} diff --git 
a/cmd/wire/main_test.go b/cmd/wire/main_test.go new file mode 100644 index 0000000..b172f62 --- /dev/null +++ b/cmd/wire/main_test.go @@ -0,0 +1,114 @@ +package main + +import ( + "bytes" + "testing" +) + +func TestFormatLoggedErrorAddsSolveHeader(t *testing.T) { + err := testError("inject InitializeApplication: no provider found for *example.Foo") + got := formatLoggedError(err) + want := "solve failed\ninject InitializeApplication: no provider found for *example.Foo" + if got != want { + t.Fatalf("formatLoggedError() = %q, want %q", got, want) + } +} + +func TestFormatLoggedErrorAddsSolveHeaderWithPositionPrefix(t *testing.T) { + err := testError("/tmp/wire.go:12:1: inject InitializeApplication: no provider found for *example.Foo") + got := formatLoggedError(err) + want := "solve failed\n/tmp/wire.go:12:1: inject InitializeApplication: no provider found for *example.Foo" + if got != want { + t.Fatalf("formatLoggedError() = %q, want %q", got, want) + } +} + +func TestFormatLoggedErrorLeavesNonSolveErrorsUnchanged(t *testing.T) { + err := testError("type-check failed for example.com/app/app") + got := formatLoggedError(err) + if got != err.Error() { + t.Fatalf("formatLoggedError() = %q, want %q", got, err.Error()) + } +} + +func TestShouldColorOutputForceColorOverridesTTYRequirement(t *testing.T) { + t.Setenv("FORCE_COLOR", "1") + t.Setenv("NO_COLOR", "") + t.Setenv("CLICOLOR", "") + t.Setenv("CLICOLOR_FORCE", "") + + if !shouldColorOutput(false, "xterm-256color") { + t.Fatal("shouldColorOutput() = false, want true when FORCE_COLOR is set") + } +} + +func TestShouldColorOutputNoColorWins(t *testing.T) { + t.Setenv("FORCE_COLOR", "1") + t.Setenv("NO_COLOR", "1") + t.Setenv("CLICOLOR", "") + t.Setenv("CLICOLOR_FORCE", "") + + if shouldColorOutput(true, "xterm-256color") { + t.Fatal("shouldColorOutput() = true, want false when NO_COLOR is set") + } +} + +func TestShouldColorOutputTTYFallback(t *testing.T) { + t.Setenv("FORCE_COLOR", "") + t.Setenv("NO_COLOR", "") + 
t.Setenv("CLICOLOR", "") + t.Setenv("CLICOLOR_FORCE", "") + + if !shouldColorOutput(true, "xterm-256color") { + t.Fatal("shouldColorOutput() = false, want true for tty stderr") + } + if shouldColorOutput(false, "xterm-256color") { + t.Fatal("shouldColorOutput() = true, want false for non-tty stderr without force color") + } +} + +func TestWriteErrorLogFormatsWirePrefix(t *testing.T) { + var buf bytes.Buffer + writeErrorLog(&buf, "type-check failed for example.com/app/app") + got := buf.String() + want := "wire: type-check failed for example.com/app/app\n" + if got != want { + t.Fatalf("writeErrorLog() = %q, want %q", got, want) + } +} + +func TestWriteErrorLogColorsWholeBlockWhenForced(t *testing.T) { + t.Setenv("FORCE_COLOR", "1") + t.Setenv("NO_COLOR", "") + t.Setenv("CLICOLOR", "") + t.Setenv("CLICOLOR_FORCE", "") + + var buf bytes.Buffer + writeErrorLog(&buf, "type-check failed for example.com/app/app") + got := buf.String() + want := ansiRed + "wire: type-check failed for example.com/app/app\n" + ansiReset + if got != want { + t.Fatalf("writeErrorLog() = %q, want %q", got, want) + } +} + +func TestWriteErrorLogColorsEachMultilineLineWhenForced(t *testing.T) { + t.Setenv("FORCE_COLOR", "1") + t.Setenv("NO_COLOR", "") + t.Setenv("CLICOLOR", "") + t.Setenv("CLICOLOR_FORCE", "") + + var buf bytes.Buffer + writeErrorLog(&buf, "\n first line\n second line") + got := buf.String() + want := ansiRed + "wire: \n" + ansiReset + + ansiRed + " first line\n" + ansiReset + + ansiRed + " second line\n" + ansiReset + if got != want { + t.Fatalf("writeErrorLog() = %q, want %q", got, want) + } +} + +type testError string + +func (e testError) Error() string { return string(e) } diff --git a/internal/wire/incremental_fingerprint.go b/internal/wire/incremental_fingerprint.go index 46485f7..42c2317 100644 --- a/internal/wire/incremental_fingerprint.go +++ b/internal/wire/incremental_fingerprint.go @@ -54,6 +54,7 @@ type fingerprintStats struct { type incrementalFingerprintSnapshot 
struct { stats fingerprintStats changed []string + touched []string fingerprints map[string]*packageFingerprint } @@ -106,6 +107,7 @@ func collectIncrementalFingerprints(wd string, tags string, pkgs []*packages.Pac continue } snapshot.stats.metaMisses++ + snapshot.touched = append(snapshot.touched, pkg.PkgPath) fp, err := buildPackageFingerprint(wd, tags, pkg, metaFiles) if err != nil { continue @@ -121,6 +123,7 @@ func collectIncrementalFingerprints(wd string, tags string, pkgs []*packages.Pac snapshot.changed = append(snapshot.changed, pkg.PkgPath) } sort.Strings(snapshot.changed) + sort.Strings(snapshot.touched) return snapshot } diff --git a/internal/wire/incremental_manifest.go b/internal/wire/incremental_manifest.go index cd88976..4a55c19 100644 --- a/internal/wire/incremental_manifest.go +++ b/internal/wire/incremental_manifest.go @@ -414,6 +414,8 @@ func incrementalManifestCurrentLocalPackages(ctx context.Context, local []packag if firstReason == "" { firstReason = fp.PkgPath + ".shape_mismatch" } + } else if firstReason == "" { + firstReason = fp.PkgPath + ".meta_changed" } } if changed, err := packageDirectoryIntroducedRelevantFiles(fp.Files); err != nil { @@ -574,6 +576,13 @@ func writeIncrementalManifestFile(key string, manifest *incrementalManifest) { } } +func removeIncrementalManifestFile(key string) { + if key == "" { + return + } + _ = osRemove(incrementalManifestPath(key)) +} + func encodeIncrementalManifest(manifest *incrementalManifest) ([]byte, error) { var buf bytes.Buffer if manifest == nil { diff --git a/internal/wire/loader_test.go b/internal/wire/loader_test.go index 6899249..d74c0f8 100644 --- a/internal/wire/loader_test.go +++ b/internal/wire/loader_test.go @@ -220,7 +220,7 @@ func TestLoadAndGenerateModuleIncrementalMatches(t *testing.T) { } } -func TestGenerateIncrementalManifestSkipsLazyLoadOnBodyOnlyChange(t *testing.T) { +func TestGenerateIncrementalBodyOnlyChangeFallsBackToLoadAndReusesOutput(t *testing.T) { lockCacheHooks(t) state 
:= saveCacheHooks() t.Cleanup(func() { restoreCacheHooks(state) }) @@ -340,17 +340,119 @@ func TestGenerateIncrementalManifestSkipsLazyLoadOnBodyOnlyChange(t *testing.T) if len(second) != 1 || len(second[0].Errs) > 0 { t.Fatalf("unexpected second Generate result: %+v", second) } - if containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected second Generate to hit preload incremental manifest before package load, labels=%v", secondLabels) - } - if containsLabel(secondLabels, "load.packages.lazy.load") { - t.Fatalf("expected second Generate to skip lazy load, labels=%v", secondLabels) + if !containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected second Generate to re-load packages after body-only change, labels=%v", secondLabels) } if string(first[0].Content) != string(second[0].Content) { t.Fatal("expected body-only change to reuse identical generated output") } } +func TestGenerateIncrementalBodyOnlyInvalidChangeDoesNotReusePreloadManifest(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() 
string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn missing", + "}", + "", + }, "\n")) + + var secondLabels []string + secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { + secondLabels = append(secondLabels, label) + }) + second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(second) != 0 { + t.Fatalf("expected invalid body-only change to stop before generation, got %+v", second) + } + if len(errs) == 0 { + t.Fatal("expected invalid body-only change to return errors") + } + if !containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected invalid body-only change to bypass preload manifest and load packages, labels=%v", secondLabels) + } + if got := errs[0].Error(); !strings.Contains(got, "undefined: missing") { + t.Fatalf("expected load/type-check error from invalid body-only change, got %q", got) + } +} + func TestGenerateIncrementalShapeChangeFallsBackToLazyLoad(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() @@ -1078,6 +1180,156 @@ func TestGenerateIncrementalInvalidShapeChangeDoesNotReuseManifest(t *testing.T) if got := 
errs[0].Error(); !strings.Contains(got, "type-check failed for example.com/app/app") { t.Fatalf("expected fast-path type-check error, got %q", got) } + if _, ok := readIncrementalManifest(incrementalManifestSelectorKey(root, env, []string{"./app"}, &GenerateOptions{})); ok { + t.Fatal("expected invalid incremental generate to invalidate selector manifest") + } +} + +func TestGenerateIncrementalRecoversAfterInvalidShapeChange(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + wireFile := filepath.Join(root, "dep", "wire.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + writeFile(t, wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if 
len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "import \"example.com/app/extra\"", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + second, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(second) != 0 { + t.Fatalf("expected invalid incremental generate to stop before generation, got %+v", second) + } + if len(errs) == 0 { + t.Fatal("expected invalid incremental generate to return errors") + } + clearIncrementalSessions() + + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string; Count int }", + "", + "func NewMessage() string { return \"a\" }", + "", + "func NewCount() int { return 7 }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: msg, Count: count}", + "}", + "", + }, "\n")) + writeFile(t, wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + + var thirdLabels []string + thirdCtx := WithTiming(ctx, func(label string, _ time.Duration) { + thirdLabels = append(thirdLabels, label) + }) + third, errs := Generate(thirdCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("recovery incremental Generate returned errors: %v", errs) + } + if len(third) != 1 || len(third[0].Errs) > 0 { + t.Fatalf("unexpected recovery incremental Generate result: %+v", third) + } + + normal, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("normal Generate returned errors: %v", errs) + } + if 
len(normal) != 1 || len(normal[0].Errs) > 0 { + t.Fatalf("unexpected normal Generate result: %+v", normal) + } + if string(third[0].Content) != string(normal[0].Content) { + t.Fatal("incremental output differs from normal Generate output after recovering from invalid shape change") + } + if !containsLabel(thirdLabels, "generate.load") { + t.Fatalf("expected recovery run to fall back to normal load after invalidating stale manifest, labels=%v", thirdLabels) + } } func TestGenerateIncrementalToggleBackToKnownShapeHitsArchivedPreloadManifest(t *testing.T) { diff --git a/internal/wire/local_fastpath.go b/internal/wire/local_fastpath.go index 4ef1f8f..04d9cb8 100644 --- a/internal/wire/local_fastpath.go +++ b/internal/wire/local_fastpath.go @@ -63,6 +63,7 @@ func tryIncrementalLocalFastPath(ctx context.Context, wd string, env []string, p if err != nil { debugf(ctx, "incremental.local_fastpath miss reason=%v", err) if shouldBypassIncrementalManifestAfterFastPathError(err) { + invalidateIncrementalPreloadState(state) return nil, true, true, []error{err} } return nil, false, false, nil @@ -104,6 +105,18 @@ func tryIncrementalLocalFastPath(ctx context.Context, wd string, env []string, p return generated, true, false, nil } +func validateIncrementalTouchedPackages(ctx context.Context, wd string, opts *GenerateOptions, state *incrementalPreloadState, snapshot *incrementalFingerprintSnapshot) error { + if state == nil || state.manifest == nil || snapshot == nil || len(snapshot.touched) == 0 { + return nil + } + roots := manifestOutputPkgPaths(state.manifest) + if len(roots) != 1 { + return nil + } + _, err := loadLocalPackagesForFastPath(ctx, wd, opts.Tags, roots[0], snapshot.touched, snapshotPackageFingerprints(snapshot), state.manifest.ExternalPkgs) + return err +} + func shouldBypassIncrementalManifestAfterFastPathError(err error) bool { if err == nil { return false @@ -115,6 +128,13 @@ func shouldBypassIncrementalManifestAfterFastPathError(err error) bool { return 
strings.Contains(msg, "type-check failed for ") } +func invalidateIncrementalPreloadState(state *incrementalPreloadState) { + if state == nil { + return + } + removeIncrementalManifestFile(state.selectorKey) +} + func formatLocalTypeCheckError(wd string, pkgPath string, errs []packages.Error) error { if len(errs) == 0 { return fmt.Errorf("type-check failed for %s", pkgPath) diff --git a/internal/wire/parser_lazy_loader.go b/internal/wire/parser_lazy_loader.go index 223c9ad..f6137bc 100644 --- a/internal/wire/parser_lazy_loader.go +++ b/internal/wire/parser_lazy_loader.go @@ -128,7 +128,8 @@ func (ll *lazyLoader) parseFileFor(pkgPath string, stats *parseFileStats) func(* return func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { start := time.Now() isPrimary := isPrimaryFile(primary, filename) - if !isPrimary && ll.session != nil { + keepBodies := ll.shouldKeepDependencyBodies(filename) + if !isPrimary && !keepBodies && ll.session != nil { if file, ok := ll.session.getParsedDep(filename, src); ok { if stats != nil { stats.record(false, time.Since(start), nil, true) @@ -153,6 +154,9 @@ func (ll *lazyLoader) parseFileFor(pkgPath string, stats *parseFileStats) func(* if isPrimary { return file, nil } + if keepBodies { + return file, nil + } for _, decl := range file.Decls { if fn, ok := decl.(*ast.FuncDecl); ok { fn.Body = nil @@ -165,3 +169,20 @@ func (ll *lazyLoader) parseFileFor(pkgPath string, stats *parseFileStats) func(* return file, nil } } + +func (ll *lazyLoader) shouldKeepDependencyBodies(filename string) bool { + if ll == nil || ll.fingerprints == nil || len(ll.fingerprints.touched) == 0 { + return false + } + clean := filepath.Clean(filename) + for _, pkgPath := range ll.fingerprints.touched { + files := ll.baseFiles[pkgPath] + if len(files) == 0 { + continue + } + if _, ok := files[clean]; ok { + return true + } + } + return false +} diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 8b617ea..1b6140f 100644 --- 
a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -129,6 +129,12 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o if len(errs) > 0 { return nil, errs } + if err := validateIncrementalTouchedPackages(ctx, wd, opts, preloadState, loader.fingerprints); err != nil { + if shouldBypassIncrementalManifestAfterFastPathError(err) { + return nil, []error{err} + } + bypassIncrementalManifest = true + } if !bypassIncrementalManifest { if cached, ok := readIncrementalManifestResults(ctx, wd, env, patterns, opts, pkgs, loader.fingerprints); ok { warmPackageOutputCache(pkgs, opts, cached) From c6e4f4e3babe2f5308e29568981b6c23b98e48d2 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 00:51:55 -0500 Subject: [PATCH 07/82] feat(incremental): harden loader and scenario tooling --- cmd/wire/gen_cmd.go | 4 +- cmd/wire/main.go | 41 +- cmd/wire/main_test.go | 31 +- cmd/wire/watch_cmd.go | 4 +- internal/wire/cache_coverage_test.go | 8 +- internal/wire/cache_key.go | 68 +- internal/wire/cache_manifest.go | 21 +- internal/wire/cache_scope.go | 69 + internal/wire/cache_scope_test.go | 59 + internal/wire/incremental_bench_test.go | 798 +++++++++++- internal/wire/incremental_fingerprint.go | 169 ++- internal/wire/incremental_fingerprint_test.go | 38 + internal/wire/incremental_graph.go | 4 +- internal/wire/incremental_manifest.go | 371 +++++- internal/wire/incremental_session.go | 2 +- internal/wire/incremental_summary.go | 2 +- internal/wire/loader_test.go | 1113 ++++++++++++++++- internal/wire/local_export.go | 97 ++ internal/wire/local_fastpath.go | 156 ++- internal/wire/wire.go | 4 + scripts/incremental-scenarios.sh | 137 ++ 21 files changed, 3076 insertions(+), 120 deletions(-) create mode 100644 internal/wire/cache_scope.go create mode 100644 internal/wire/cache_scope_test.go create mode 100644 internal/wire/local_export.go create mode 100755 scripts/incremental-scenarios.sh diff --git a/cmd/wire/gen_cmd.go b/cmd/wire/gen_cmd.go 
index 13b88ed..e98556f 100644 --- a/cmd/wire/gen_cmd.go +++ b/cmd/wire/gen_cmd.go @@ -112,9 +112,9 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa } if wrote, err := out.CommitWithStatus(); err == nil { if wrote { - log.Printf("%s: wrote %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + logSuccessf("%s: wrote %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) } else { - log.Printf("%s: unchanged %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + logSuccessf("%s: unchanged %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) } } else { log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) diff --git a/cmd/wire/main.go b/cmd/wire/main.go index d40e439..efaf767 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -38,8 +38,12 @@ import ( var topLevelIncremental optionalBoolFlag const ( - ansiRed = "\033[1;31m" - ansiReset = "\033[0m" + ansiRed = "\033[1;31m" + ansiGreen = "\033[1;32m" + ansiReset = "\033[0m" + successSig = "✓ " + errorSig = "x " + maxLoggedErrorLines = 5 ) // main wires up subcommands and executes the selected command. @@ -209,7 +213,7 @@ func newGenerateOptions(headerFile string) (*wire.GenerateOptions, error) { // logErrors logs each error with consistent formatting. func logErrors(errs []error) { for _, err := range errs { - msg := formatLoggedError(err) + msg := truncateLoggedError(formatLoggedError(err)) if strings.Contains(msg, "\n") { logMultilineError("\n " + strings.ReplaceAll(msg, "\n", "\n ")) continue @@ -232,10 +236,27 @@ func formatLoggedError(err error) string { return msg } +func truncateLoggedError(msg string) string { + if msg == "" { + return "" + } + lines := strings.Split(msg, "\n") + if len(lines) <= maxLoggedErrorLines { + return msg + } + omitted := len(lines) - maxLoggedErrorLines + lines = append(lines[:maxLoggedErrorLines], fmt.Sprintf("... 
(%d additional lines omitted)", omitted)) + return strings.Join(lines, "\n") +} + func logMultilineError(msg string) { writeErrorLog(os.Stderr, msg) } +func logSuccessf(format string, args ...interface{}) { + writeStatusLog(os.Stderr, fmt.Sprintf(format, args...)) +} + func shouldColorStderr() bool { return shouldColorOutput(stderrIsTTY(), os.Getenv("TERM")) } @@ -266,7 +287,7 @@ func stderrIsTTY() bool { } func writeErrorLog(w io.Writer, msg string) { - line := "wire: " + msg + line := errorSig + "wire: " + msg if !strings.HasSuffix(line, "\n") { line += "\n" } @@ -277,6 +298,18 @@ func writeErrorLog(w io.Writer, msg string) { _, _ = io.WriteString(w, line) } +func writeStatusLog(w io.Writer, msg string) { + line := successSig + "wire: " + msg + if !strings.HasSuffix(line, "\n") { + line += "\n" + } + if shouldColorStderr() { + _, _ = io.WriteString(w, ansiGreen+line+ansiReset) + return + } + _, _ = io.WriteString(w, line) +} + func colorizeLines(s string) string { if s == "" { return "" diff --git a/cmd/wire/main_test.go b/cmd/wire/main_test.go index b172f62..7fe4720 100644 --- a/cmd/wire/main_test.go +++ b/cmd/wire/main_test.go @@ -2,6 +2,8 @@ package main import ( "bytes" + "fmt" + "strings" "testing" ) @@ -31,6 +33,19 @@ func TestFormatLoggedErrorLeavesNonSolveErrorsUnchanged(t *testing.T) { } } +func TestTruncateLoggedErrorSummarizesLargeBlocks(t *testing.T) { + lines := make([]string, 0, maxLoggedErrorLines+3) + for i := 0; i < maxLoggedErrorLines+3; i++ { + lines = append(lines, fmt.Sprintf("line %d", i+1)) + } + got := truncateLoggedError(strings.Join(lines, "\n")) + wantLines := append(append([]string(nil), lines[:maxLoggedErrorLines]...), "... 
(3 additional lines omitted)") + want := strings.Join(wantLines, "\n") + if got != want { + t.Fatalf("truncateLoggedError() = %q, want %q", got, want) + } +} + func TestShouldColorOutputForceColorOverridesTTYRequirement(t *testing.T) { t.Setenv("FORCE_COLOR", "1") t.Setenv("NO_COLOR", "") @@ -71,7 +86,7 @@ func TestWriteErrorLogFormatsWirePrefix(t *testing.T) { var buf bytes.Buffer writeErrorLog(&buf, "type-check failed for example.com/app/app") got := buf.String() - want := "wire: type-check failed for example.com/app/app\n" + want := errorSig + "wire: type-check failed for example.com/app/app\n" if got != want { t.Fatalf("writeErrorLog() = %q, want %q", got, want) } @@ -86,7 +101,7 @@ func TestWriteErrorLogColorsWholeBlockWhenForced(t *testing.T) { var buf bytes.Buffer writeErrorLog(&buf, "type-check failed for example.com/app/app") got := buf.String() - want := ansiRed + "wire: type-check failed for example.com/app/app\n" + ansiReset + want := ansiRed + errorSig + "wire: type-check failed for example.com/app/app\n" + ansiReset if got != want { t.Fatalf("writeErrorLog() = %q, want %q", got, want) } @@ -101,7 +116,7 @@ func TestWriteErrorLogColorsEachMultilineLineWhenForced(t *testing.T) { var buf bytes.Buffer writeErrorLog(&buf, "\n first line\n second line") got := buf.String() - want := ansiRed + "wire: \n" + ansiReset + + want := ansiRed + errorSig + "wire: \n" + ansiReset + ansiRed + " first line\n" + ansiReset + ansiRed + " second line\n" + ansiReset if got != want { @@ -109,6 +124,16 @@ func TestWriteErrorLogColorsEachMultilineLineWhenForced(t *testing.T) { } } +func TestWriteStatusLogFormatsSuccessPrefix(t *testing.T) { + var buf bytes.Buffer + writeStatusLog(&buf, "example.com/app: wrote /tmp/wire_gen.go (12ms)") + got := buf.String() + want := successSig + "wire: example.com/app: wrote /tmp/wire_gen.go (12ms)\n" + if got != want { + t.Fatalf("writeStatusLog() = %q, want %q", got, want) + } +} + type testError string func (e testError) Error() string { 
return string(e) } diff --git a/cmd/wire/watch_cmd.go b/cmd/wire/watch_cmd.go index 13743cd..cb1b31b 100644 --- a/cmd/wire/watch_cmd.go +++ b/cmd/wire/watch_cmd.go @@ -131,9 +131,9 @@ func (cmd *watchCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter } if wrote, err := out.CommitWithStatus(); err == nil { if wrote { - log.Printf("%s: wrote %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + logSuccessf("%s: wrote %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) } else { - log.Printf("%s: unchanged %s (%s)\n", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + logSuccessf("%s: unchanged %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) } } else { log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) diff --git a/internal/wire/cache_coverage_test.go b/internal/wire/cache_coverage_test.go index 316f30e..faf6e62 100644 --- a/internal/wire/cache_coverage_test.go +++ b/internal/wire/cache_coverage_test.go @@ -605,16 +605,18 @@ func TestManifestKeyHelpers(t *testing.T) { PrefixOutputFile: "prefix", Header: []byte("header"), } + wd := t.TempDir() + patterns := []string{"./a", "./b"} manifest := &cacheManifest{ - WD: t.TempDir(), + WD: runCacheScope(wd, patterns), EnvHash: envHash(env), Tags: opts.Tags, Prefix: opts.PrefixOutputFile, HeaderHash: headerHash(opts.Header), - Patterns: []string{"./a", "./b"}, + Patterns: normalizePatternsForScope(wd, packageCacheScope(wd), patterns), } got := manifestKeyFromManifest(manifest) - want := manifestKey(manifest.WD, env, manifest.Patterns, opts) + want := manifestKey(wd, env, patterns, opts) if got != want { t.Fatalf("manifest key mismatch: got %q, want %q", got, want) } diff --git a/internal/wire/cache_key.go b/internal/wire/cache_key.go index 2aa8881..f22c6c0 100644 --- a/internal/wire/cache_key.go +++ b/internal/wire/cache_key.go @@ -18,7 +18,9 @@ import ( "crypto/sha256" "fmt" 
"path/filepath" + "runtime" "sort" + "sync" "golang.org/x/tools/go/packages" ) @@ -209,17 +211,69 @@ func cacheMetaMatches(meta *cacheMeta, pkg *packages.Package, opts *GenerateOpti // buildCacheFiles converts file paths into cache metadata entries. func buildCacheFiles(files []string) ([]cacheFile, error) { - out := make([]cacheFile, 0, len(files)) - for _, name := range files { - info, err := osStat(name) + return buildCacheFilesWithStats(files, func(path string) (cacheFile, error) { + info, err := osStat(path) if err != nil { - return nil, err + return cacheFile{}, err } - out = append(out, cacheFile{ - Path: filepath.Clean(name), + return cacheFile{ + Path: filepath.Clean(path), Size: info.Size(), ModTime: info.ModTime().UnixNano(), - }) + }, nil + }) +} + +func buildCacheFilesWithStats[T any](items []T, stat func(T) (cacheFile, error)) ([]cacheFile, error) { + if len(items) == 0 { + return nil, nil + } + if len(items) == 1 { + file, err := stat(items[0]) + if err != nil { + return nil, err + } + return []cacheFile{file}, nil + } + out := make([]cacheFile, len(items)) + workers := runtime.GOMAXPROCS(0) + if workers < 4 { + workers = 4 + } + if workers > len(items) { + workers = len(items) + } + var ( + wg sync.WaitGroup + mu sync.Mutex + firstErr error + indexCh = make(chan int, len(items)) + ) + for i := range items { + indexCh <- i + } + close(indexCh) + wg.Add(workers) + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for i := range indexCh { + file, err := stat(items[i]) + if err != nil { + mu.Lock() + if firstErr == nil { + firstErr = err + } + mu.Unlock() + continue + } + out[i] = file + } + }() + } + wg.Wait() + if firstErr != nil { + return nil, firstErr } return out, nil } diff --git a/internal/wire/cache_manifest.go b/internal/wire/cache_manifest.go index 127aa55..57be68b 100644 --- a/internal/wire/cache_manifest.go +++ b/internal/wire/cache_manifest.go @@ -79,14 +79,15 @@ func writeManifest(wd string, env []string, patterns []string, 
opts *GenerateOpt return } key := manifestKey(wd, env, patterns, opts) + scope := runCacheScope(wd, patterns) manifest := &cacheManifest{ Version: cacheVersion, - WD: wd, + WD: scope, Tags: opts.Tags, Prefix: opts.PrefixOutputFile, HeaderHash: headerHash(opts.Header), EnvHash: envHash(env), - Patterns: sortedStrings(patterns), + Patterns: normalizePatternsForScope(wd, packageCacheScope(wd), patterns), } manifest.ExtraFiles = extraCacheFiles(wd) for _, pkg := range pkgs { @@ -138,7 +139,7 @@ func manifestKey(wd string, env []string, patterns []string, opts *GenerateOptio h := sha256.New() h.Write([]byte(cacheVersion)) h.Write([]byte{0}) - h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte(runCacheScope(wd, patterns))) h.Write([]byte{0}) h.Write([]byte(envHash(env))) h.Write([]byte{0}) @@ -148,7 +149,7 @@ func manifestKey(wd string, env []string, patterns []string, opts *GenerateOptio h.Write([]byte{0}) h.Write([]byte(headerHash(opts.Header))) h.Write([]byte{0}) - for _, p := range sortedStrings(patterns) { + for _, p := range normalizePatternsForScope(wd, packageCacheScope(wd), patterns) { h.Write([]byte(p)) h.Write([]byte{0}) } @@ -293,19 +294,17 @@ func manifestValid(manifest *cacheManifest) bool { // buildCacheFilesFromMeta re-stats files to compare metadata. func buildCacheFilesFromMeta(files []cacheFile) ([]cacheFile, error) { - out := make([]cacheFile, 0, len(files)) - for _, file := range files { + return buildCacheFilesWithStats(files, func(file cacheFile) (cacheFile, error) { info, err := osStat(file.Path) if err != nil { - return nil, err + return cacheFile{}, err } - out = append(out, cacheFile{ + return cacheFile{ Path: filepath.Clean(file.Path), Size: info.Size(), ModTime: info.ModTime().UnixNano(), - }) - } - return out, nil + }, nil + }) } // extraCacheFiles returns Go module/workspace files affecting builds. 
diff --git a/internal/wire/cache_scope.go b/internal/wire/cache_scope.go new file mode 100644 index 0000000..fe161a7 --- /dev/null +++ b/internal/wire/cache_scope.go @@ -0,0 +1,69 @@ +package wire + +import ( + "path/filepath" + "sort" + "strings" +) + +func packageCacheScope(wd string) string { + if root := findModuleRoot(wd); root != "" { + return filepath.Clean(root) + } + return filepath.Clean(wd) +} + +func runCacheScope(wd string, patterns []string) string { + scopeRoot := packageCacheScope(wd) + normalized := normalizePatternsForScope(wd, scopeRoot, patterns) + if len(normalized) == 0 { + return scopeRoot + } + return scopeRoot + "\n" + strings.Join(normalized, "\n") +} + +func normalizePatternsForScope(wd string, scopeRoot string, patterns []string) []string { + if len(patterns) == 0 { + return nil + } + out := make([]string, 0, len(patterns)) + for _, pattern := range patterns { + out = append(out, normalizePatternForScope(wd, scopeRoot, pattern)) + } + sort.Strings(out) + return out +} + +func normalizePatternForScope(wd string, scopeRoot string, pattern string) string { + if pattern == "" { + return pattern + } + if filepath.IsAbs(pattern) || strings.HasPrefix(pattern, ".") { + abs := pattern + if !filepath.IsAbs(abs) { + abs = filepath.Join(wd, pattern) + } + abs = filepath.Clean(abs) + if scopeRoot != "" { + if rel, ok := pathWithinRoot(scopeRoot, abs); ok { + if rel == "." { + return "." + } + return filepath.ToSlash(rel) + } + } + return filepath.ToSlash(abs) + } + return pattern +} + +func pathWithinRoot(root string, path string) (string, bool) { + rel, err := filepath.Rel(filepath.Clean(root), filepath.Clean(path)) + if err != nil { + return "", false + } + if rel == ".." 
|| strings.HasPrefix(rel, ".."+string(filepath.Separator)) { + return "", false + } + return rel, true +} diff --git a/internal/wire/cache_scope_test.go b/internal/wire/cache_scope_test.go new file mode 100644 index 0000000..9cc518b --- /dev/null +++ b/internal/wire/cache_scope_test.go @@ -0,0 +1,59 @@ +package wire + +import ( + "path/filepath" + "testing" +) + +func TestRunScopedKeysIgnoreEquivalentWorkingDirectories(t *testing.T) { + root := t.TempDir() + writeFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.22\n") + wireDir := filepath.Join(root, "wire") + writeFile(t, filepath.Join(wireDir, "wire.go"), "package wire\n") + + env := []string{"GOOS=darwin"} + opts := &GenerateOptions{Tags: "wireinject", PrefixOutputFile: "gen_"} + + rootKey := manifestKey(root, env, []string{"./wire"}, opts) + subdirKey := manifestKey(wireDir, env, []string{"."}, opts) + if rootKey != subdirKey { + t.Fatalf("manifestKey mismatch: root=%q subdir=%q", rootKey, subdirKey) + } + + rootIncrementalKey := incrementalManifestSelectorKey(root, env, []string{"./wire"}, opts) + subdirIncrementalKey := incrementalManifestSelectorKey(wireDir, env, []string{"."}, opts) + if rootIncrementalKey != subdirIncrementalKey { + t.Fatalf("incrementalManifestSelectorKey mismatch: root=%q subdir=%q", rootIncrementalKey, subdirIncrementalKey) + } +} + +func TestPackageScopedKeysIgnoreEquivalentWorkingDirectories(t *testing.T) { + root := t.TempDir() + writeFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.22\n") + wireDir := filepath.Join(root, "wire") + writeFile(t, filepath.Join(wireDir, "wire.go"), "package wire\n") + + rootFingerprintKey := incrementalFingerprintKey(root, "wireinject", "example.com/app/wire") + subdirFingerprintKey := incrementalFingerprintKey(wireDir, "wireinject", "example.com/app/wire") + if rootFingerprintKey != subdirFingerprintKey { + t.Fatalf("incrementalFingerprintKey mismatch: root=%q subdir=%q", rootFingerprintKey, 
subdirFingerprintKey) + } + + rootSummaryKey := incrementalSummaryKey(root, "wireinject", "example.com/app/wire") + subdirSummaryKey := incrementalSummaryKey(wireDir, "wireinject", "example.com/app/wire") + if rootSummaryKey != subdirSummaryKey { + t.Fatalf("incrementalSummaryKey mismatch: root=%q subdir=%q", rootSummaryKey, subdirSummaryKey) + } + + rootGraphKey := incrementalGraphKey(root, "wireinject", []string{"example.com/app/wire"}) + subdirGraphKey := incrementalGraphKey(wireDir, "wireinject", []string{"example.com/app/wire"}) + if rootGraphKey != subdirGraphKey { + t.Fatalf("incrementalGraphKey mismatch: root=%q subdir=%q", rootGraphKey, subdirGraphKey) + } + + rootSessionKey := sessionKey(root, []string{"GOOS=darwin"}, "wireinject") + subdirSessionKey := sessionKey(wireDir, []string{"GOOS=darwin"}, "wireinject") + if rootSessionKey != subdirSessionKey { + t.Fatalf("sessionKey mismatch: root=%q subdir=%q", rootSessionKey, subdirSessionKey) + } +} diff --git a/internal/wire/incremental_bench_test.go b/internal/wire/incremental_bench_test.go index 911a8c7..b300c4f 100644 --- a/internal/wire/incremental_bench_test.go +++ b/internal/wire/incremental_bench_test.go @@ -5,10 +5,12 @@ import ( "fmt" "os" "path/filepath" + "sort" "strconv" "strings" "testing" "time" + "unicode/utf8" ) const ( @@ -18,6 +20,38 @@ const ( var largeBenchmarkSizes = []int{10, 100, 1000} +type incrementalScenarioBenchmarkCase struct { + name string + mutate func(tb testing.TB, root string) + measure func(tb testing.TB, root string, env []string, ctx context.Context) incrementalScenarioTrace + wantErr bool +} + +type incrementalScenarioTrace struct { + total time.Duration + labels map[string]time.Duration +} + +type incrementalScenarioBudget struct { + total time.Duration + validateLocal time.Duration + validateExt time.Duration + validateTouch time.Duration + validateTouchHit time.Duration + outputs time.Duration + generateLoad time.Duration + localFastpath time.Duration +} + +type 
largeRepoPerformanceBudget struct { + shapeTotal time.Duration + localLoad time.Duration + parse time.Duration + typecheck time.Duration + generate time.Duration + knownToggle time.Duration +} + func BenchmarkGenerateIncrementalFirstSeenShapeChange(b *testing.B) { cacheHooksMu.Lock() state := saveCacheHooks() @@ -77,6 +111,129 @@ func BenchmarkGenerateIncrementalFirstSeenShapeChange(b *testing.B) { } } +func BenchmarkGenerateIncrementalScenarioMatrix(b *testing.B) { + cacheHooksMu.Lock() + state := saveCacheHooks() + b.Cleanup(func() { + restoreCacheHooks(state) + cacheHooksMu.Unlock() + }) + + repoRoot := benchmarkRepoRoot(b) + for _, scenario := range incrementalScenarioBenchmarks() { + scenario := scenario + b.Run(scenario.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + b.StartTimer() + _ = measureIncrementalScenarioOnce(b, repoRoot, scenario) + b.StopTimer() + } + }) + } +} + +func TestPrintIncrementalScenarioBenchmarkTable(t *testing.T) { + if os.Getenv("WIRE_BENCH_SCENARIOS") == "" { + t.Skip("set WIRE_BENCH_SCENARIOS=1 to print the incremental scenario benchmark table") + } + + cacheHooksMu.Lock() + state := saveCacheHooks() + t.Cleanup(func() { + restoreCacheHooks(state) + cacheHooksMu.Unlock() + }) + + repoRoot := benchmarkRepoRoot(t) + rows := [][]string{{ + "scenario", + "total", + "local pkgs", + "external", + "touched", + "touch hit", + "outputs", + "gen load", + "local fastpath", + }} + for _, scenario := range incrementalScenarioBenchmarks() { + trace := measureIncrementalScenarioOnce(t, repoRoot, scenario) + rows = append(rows, []string{ + scenario.name, + formatBenchmarkDuration(trace.total), + formatBenchmarkDuration(trace.label("incremental.preload_manifest.validate_local_packages")), + formatBenchmarkDuration(trace.label("incremental.preload_manifest.validate_external_files")), + formatBenchmarkDuration(trace.label("incremental.preload_manifest.validate_touched")), + 
formatBenchmarkDuration(trace.label("incremental.preload_manifest.validate_touched_cache_hit")), + formatBenchmarkDuration(trace.label("incremental.preload_manifest.outputs")), + formatBenchmarkDuration(trace.label("generate.load")), + formatBenchmarkDuration(trace.label("incremental.local_fastpath.load")), + }) + } + fmt.Print(renderASCIITable(rows)) +} + +func TestIncrementalScenarioPerformanceBudgets(t *testing.T) { + if os.Getenv("WIRE_PERF_BUDGETS") == "" { + t.Skip("set WIRE_PERF_BUDGETS=1 to enforce incremental scenario performance budgets") + } + + cacheHooksMu.Lock() + state := saveCacheHooks() + t.Cleanup(func() { + restoreCacheHooks(state) + cacheHooksMu.Unlock() + }) + + repoRoot := benchmarkRepoRoot(t) + budgets := incrementalScenarioPerformanceBudgets() + for _, scenario := range incrementalScenarioBenchmarks() { + scenario := scenario + budget, ok := budgets[scenario.name] + if !ok { + t.Fatalf("missing performance budget for scenario %q", scenario.name) + } + t.Run(scenario.name, func(t *testing.T) { + trace := measureIncrementalScenarioMedian(t, repoRoot, scenario, 5) + assertScenarioBudget(t, trace, budget) + }) + } +} + +func TestLargeRepoPerformanceBudgets(t *testing.T) { + if os.Getenv("WIRE_PERF_BUDGETS") == "" { + t.Skip("set WIRE_PERF_BUDGETS=1 to enforce incremental scenario performance budgets") + } + + cacheHooksMu.Lock() + state := saveCacheHooks() + t.Cleanup(func() { + restoreCacheHooks(state) + cacheHooksMu.Unlock() + }) + + repoRoot := benchmarkRepoRoot(t) + budgets := largeRepoPerformanceBudgets() + for _, packageCount := range largeBenchmarkSizes { + packageCount := packageCount + budget, ok := budgets[packageCount] + if !ok { + t.Fatalf("missing large-repo performance budget for size %d", packageCount) + } + t.Run(strconv.Itoa(packageCount), func(t *testing.T) { + trace := measureLargeRepoShapeChangeTraceMedian(t, repoRoot, packageCount, true, 3) + checkBudgetDuration(t, "shape_total", trace.total, budget.shapeTotal) + 
checkBudgetDuration(t, "local_fastpath_load", trace.label("incremental.local_fastpath.load"), budget.localLoad) + checkBudgetDuration(t, "parse", trace.label("incremental.local_fastpath.parse"), budget.parse) + checkBudgetDuration(t, "typecheck", trace.label("incremental.local_fastpath.typecheck"), budget.typecheck) + checkBudgetDuration(t, "generate", trace.label("incremental.local_fastpath.generate"), budget.generate) + + knownToggle := measureLargeRepoKnownToggleMedian(t, repoRoot, packageCount, 3) + checkBudgetDuration(t, "known_toggle", knownToggle, budget.knownToggle) + }) + } +} + func BenchmarkGenerateLargeRepoNormalShapeChange(b *testing.B) { runLargeRepoShapeChangeBenchmarks(b, false) } @@ -164,8 +321,10 @@ func TestPrintLargeRepoShapeChangeBreakdownTable(t *testing.T) { "old typed load", "new total", "new local load", - "new cached sets", + "new parse", + "new typecheck", "new injector solve", + "new format", "new generate", "speedup", }} @@ -179,8 +338,10 @@ func TestPrintLargeRepoShapeChangeBreakdownTable(t *testing.T) { formatBenchmarkDuration(normal.label("load.packages.lazy.load")), formatBenchmarkDuration(incremental.total), formatBenchmarkDuration(incremental.label("incremental.local_fastpath.load")), - formatBenchmarkDuration(incremental.label("incremental.local_fastpath.summary_resolve")), + formatBenchmarkDuration(incremental.label("incremental.local_fastpath.parse")), + formatBenchmarkDuration(incremental.label("incremental.local_fastpath.typecheck")), formatBenchmarkDuration(incremental.label("generate.package.example.com/app/app.injectors")), + formatBenchmarkDuration(incremental.label("generate.package.example.com/app/app.format")), formatBenchmarkDuration(incremental.label("incremental.local_fastpath.generate")), fmt.Sprintf("%.2fx", speedupRatio(normal.total, incremental.total)), }) @@ -314,6 +475,576 @@ func runLargeRepoShapeChangeBenchmarks(b *testing.B, incremental bool) { } } +func incrementalScenarioBenchmarks() 
[]incrementalScenarioBenchmarkCase { + return []incrementalScenarioBenchmarkCase{ + { + name: "preload_unchanged", + mutate: func(testing.TB, string) {}, + }, + { + name: "preload_whitespace_only_change", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "", + "func New(msg string) *Foo {", + "", + "\treturn &Foo{Message: helper(msg)}", + "", + "}", + "", + }, "\n")) + }, + }, + { + name: "preload_body_only_change", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string {", + "\treturn helper(SQLText)", + "}", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + }, + { + name: "preload_body_only_repeat_change", + measure: func(tb testing.TB, root string, env []string, ctx context.Context) incrementalScenarioTrace { + writeBodyOnlyScenarioVariant(tb, root, "b") + if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + tb.Fatalf("warm changed variant Generate returned errors: %v", errs) + } + writeBodyOnlyScenarioVariant(tb, root, "a") + if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + tb.Fatalf("reset variant Generate returned errors: %v", errs) + } + writeBodyOnlyScenarioVariant(tb, root, "b") + trace := incrementalScenarioTrace{labels: make(map[string]time.Duration)} + timedCtx := WithTiming(ctx, 
func(label string, dur time.Duration) { + trace.labels[label] += dur + }) + start := time.Now() + gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) + trace.total = time.Since(start) + if len(errs) > 0 { + tb.Fatalf("%s: Generate returned errors: %v", "preload_body_only_repeat_change", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + tb.Fatalf("%s: unexpected Generate results: %+v", "preload_body_only_repeat_change", gens) + } + return trace + }, + }, + { + name: "local_fastpath_method_body_change", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func (f Foo) Summary() string {", + "\treturn helper(f.Message)", + "}", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + }, + }, + { + name: "preload_const_value_change", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"blue\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + }, + { + name: "preload_var_initializer_change", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 2", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return 
SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + }, + { + name: "local_fastpath_add_top_level_helper", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func NewTag() string { return \"tag\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + }, + { + name: "preload_import_only_implementation_change", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return fmt.Sprint(msg) }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + }, + { + name: "local_fastpath_signature_change", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 7", + "", + "type Foo struct {", + "\tMessage string", + "\tCount int", + "}", + "", + "func NewMessage() string { return SQLText }", + "", + "func NewCount() int { return defaultCount }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: helper(msg), Count: count}", + "}", + "", + }, "\n")) + 
writeBenchmarkFile(tb, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + }, + }, + { + name: "local_fastpath_struct_field_addition", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct {", + "\tMessage string", + "\tCount int", + "}", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg), Count: defaultCount}", + "}", + "", + }, "\n")) + }, + }, + { + name: "local_fastpath_interface_method_addition", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Fooer interface {", + "\tMessage() string", + "\tCount() int", + "}", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + }, + { + name: "fallback_invalid_body_change", + mutate: func(tb testing.TB, root string) { + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return missing }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + wantErr: 
true, + }, + } +} + +func incrementalScenarioPerformanceBudgets() map[string]incrementalScenarioBudget { + return map[string]incrementalScenarioBudget{ + "preload_unchanged": { + total: 300 * time.Millisecond, + validateLocal: 25 * time.Millisecond, + validateExt: 25 * time.Millisecond, + validateTouch: 5 * time.Millisecond, + outputs: 5 * time.Millisecond, + }, + "preload_whitespace_only_change": { + total: 300 * time.Millisecond, + validateLocal: 25 * time.Millisecond, + validateExt: 25 * time.Millisecond, + validateTouch: 250 * time.Millisecond, + outputs: 5 * time.Millisecond, + }, + "preload_body_only_change": { + total: 400 * time.Millisecond, + validateLocal: 40 * time.Millisecond, + validateExt: 40 * time.Millisecond, + validateTouch: 250 * time.Millisecond, + outputs: 5 * time.Millisecond, + }, + "preload_body_only_repeat_change": { + total: 150 * time.Millisecond, + validateLocal: 40 * time.Millisecond, + validateExt: 40 * time.Millisecond, + validateTouch: 5 * time.Millisecond, + validateTouchHit: 5 * time.Millisecond, + outputs: 5 * time.Millisecond, + }, + "local_fastpath_method_body_change": { + total: 500 * time.Millisecond, + validateLocal: 60 * time.Millisecond, + validateExt: 60 * time.Millisecond, + localFastpath: 300 * time.Millisecond, + }, + "preload_import_only_implementation_change": { + total: 150 * time.Millisecond, + validateLocal: 40 * time.Millisecond, + validateExt: 40 * time.Millisecond, + validateTouch: 50 * time.Millisecond, + outputs: 5 * time.Millisecond, + }, + "preload_const_value_change": { + total: 400 * time.Millisecond, + validateLocal: 40 * time.Millisecond, + validateExt: 40 * time.Millisecond, + validateTouch: 250 * time.Millisecond, + outputs: 5 * time.Millisecond, + }, + "preload_var_initializer_change": { + total: 400 * time.Millisecond, + validateLocal: 40 * time.Millisecond, + validateExt: 40 * time.Millisecond, + validateTouch: 250 * time.Millisecond, + outputs: 5 * time.Millisecond, + }, + 
"local_fastpath_add_top_level_helper": { + total: 500 * time.Millisecond, + validateLocal: 60 * time.Millisecond, + validateExt: 60 * time.Millisecond, + localFastpath: 300 * time.Millisecond, + }, + "local_fastpath_signature_change": { + total: 500 * time.Millisecond, + validateLocal: 60 * time.Millisecond, + validateExt: 60 * time.Millisecond, + localFastpath: 300 * time.Millisecond, + }, + "local_fastpath_struct_field_addition": { + total: 500 * time.Millisecond, + validateLocal: 60 * time.Millisecond, + validateExt: 60 * time.Millisecond, + localFastpath: 300 * time.Millisecond, + }, + "local_fastpath_interface_method_addition": { + total: 500 * time.Millisecond, + validateLocal: 60 * time.Millisecond, + validateExt: 60 * time.Millisecond, + localFastpath: 300 * time.Millisecond, + }, + "fallback_invalid_body_change": { + total: 800 * time.Millisecond, + generateLoad: 500 * time.Millisecond, + }, + } +} + +func measureIncrementalScenarioOnce(tb testing.TB, repoRoot string, scenario incrementalScenarioBenchmarkCase) incrementalScenarioTrace { + tb.Helper() + + cacheRoot := tb.TempDir() + osTempDir = func() string { return cacheRoot } + + root := tb.TempDir() + writeIncrementalScenarioBenchmarkModule(tb, repoRoot, root) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + tb.Fatalf("baseline Generate returned errors: %v", errs) + } + + if scenario.measure != nil { + return scenario.measure(tb, root, env, ctx) + } + + scenario.mutate(tb, root) + + trace := incrementalScenarioTrace{labels: make(map[string]time.Duration)} + timedCtx := WithTiming(ctx, func(label string, dur time.Duration) { + trace.labels[label] += dur + }) + start := time.Now() + gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) + trace.total = time.Since(start) + + if scenario.wantErr { + if len(errs) == 0 { + 
tb.Fatalf("%s: expected Generate errors", scenario.name) + } + if len(gens) != 0 { + tb.Fatalf("%s: expected no generated results on error, got %+v", scenario.name, gens) + } + return trace + } + + if len(errs) > 0 { + tb.Fatalf("%s: Generate returned errors: %v", scenario.name, errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + tb.Fatalf("%s: unexpected Generate results: %+v", scenario.name, gens) + } + return trace +} + +func writeIncrementalScenarioBenchmarkModule(tb testing.TB, repoRoot string, root string) { + tb.Helper() + + writeBenchmarkFile(tb, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeBenchmarkFile(tb, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + writeBodyOnlyScenarioVariant(tb, root, "green") +} + +func writeBodyOnlyScenarioVariant(tb testing.TB, root string, value string) { + tb.Helper() + writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "const SQLText = \"" + value + "\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + + writeBenchmarkFile(tb, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) +} + +func 
measureIncrementalScenarioMedian(tb testing.TB, repoRoot string, scenario incrementalScenarioBenchmarkCase, samples int) incrementalScenarioTrace { + tb.Helper() + if samples <= 0 { + samples = 1 + } + traces := make([]incrementalScenarioTrace, 0, samples) + for i := 0; i < samples; i++ { + traces = append(traces, measureIncrementalScenarioOnce(tb, repoRoot, scenario)) + } + sort.Slice(traces, func(i, j int) bool { return traces[i].total < traces[j].total }) + return traces[len(traces)/2] +} + +func assertScenarioBudget(t *testing.T, trace incrementalScenarioTrace, budget incrementalScenarioBudget) { + t.Helper() + checkBudgetDuration(t, "total", trace.total, budget.total) + checkBudgetDuration(t, "validate_local_packages", trace.label("incremental.preload_manifest.validate_local_packages"), budget.validateLocal) + checkBudgetDuration(t, "validate_external_files", trace.label("incremental.preload_manifest.validate_external_files"), budget.validateExt) + checkBudgetDuration(t, "validate_touched", trace.label("incremental.preload_manifest.validate_touched"), budget.validateTouch) + checkBudgetDuration(t, "validate_touched_cache_hit", trace.label("incremental.preload_manifest.validate_touched_cache_hit"), budget.validateTouchHit) + checkBudgetDuration(t, "outputs", trace.label("incremental.preload_manifest.outputs"), budget.outputs) + checkBudgetDuration(t, "generate_load", trace.label("generate.load"), budget.generateLoad) + checkBudgetDuration(t, "local_fastpath_load", trace.label("incremental.local_fastpath.load"), budget.localFastpath) +} + +func checkBudgetDuration(t *testing.T, name string, got time.Duration, max time.Duration) { + t.Helper() + if max <= 0 { + return + } + if got > max { + t.Fatalf("%s exceeded budget: got=%s max=%s", name, got, max) + } +} + +func (s incrementalScenarioTrace) label(name string) time.Duration { + if s.labels == nil { + return 0 + } + return s.labels[name] +} + type largeRepoBenchmarkRow struct { packageCount int coldNormal 
time.Duration @@ -328,6 +1059,35 @@ type shapeChangeTrace struct { labels map[string]time.Duration } +func largeRepoPerformanceBudgets() map[int]largeRepoPerformanceBudget { + return map[int]largeRepoPerformanceBudget{ + 10: { + shapeTotal: 45 * time.Millisecond, + localLoad: 3 * time.Millisecond, + parse: 500 * time.Microsecond, + typecheck: 4 * time.Millisecond, + generate: 3 * time.Millisecond, + knownToggle: 3 * time.Millisecond, + }, + 100: { + shapeTotal: 35 * time.Millisecond, + localLoad: 20 * time.Millisecond, + parse: 1500 * time.Microsecond, + typecheck: 12 * time.Millisecond, + generate: 20 * time.Millisecond, + knownToggle: 15 * time.Millisecond, + }, + 1000: { + shapeTotal: 260 * time.Millisecond, + localLoad: 110 * time.Millisecond, + parse: 4 * time.Millisecond, + typecheck: 70 * time.Millisecond, + generate: 180 * time.Millisecond, + knownToggle: 90 * time.Millisecond, + }, + } +} + func measureLargeRepoShapeChangeOnce(tb testing.TB, repoRoot string, packageCount int, incremental bool) time.Duration { tb.Helper() @@ -394,6 +1154,19 @@ func measureLargeRepoShapeChangeTraceOnce(tb testing.TB, repoRoot string, packag return trace } +func measureLargeRepoShapeChangeTraceMedian(tb testing.TB, repoRoot string, packageCount int, incremental bool, samples int) shapeChangeTrace { + tb.Helper() + if samples <= 0 { + samples = 1 + } + traces := make([]shapeChangeTrace, 0, samples) + for i := 0; i < samples; i++ { + traces = append(traces, measureLargeRepoShapeChangeTraceOnce(tb, repoRoot, packageCount, incremental)) + } + sort.Slice(traces, func(i, j int) bool { return traces[i].total < traces[j].total }) + return traces[len(traces)/2] +} + func (s shapeChangeTrace) label(name string) time.Duration { if s.labels == nil { return 0 @@ -466,6 +1239,19 @@ func measureLargeRepoKnownToggleOnce(tb testing.TB, repoRoot string, packageCoun return dur } +func measureLargeRepoKnownToggleMedian(tb testing.TB, repoRoot string, packageCount int, samples int) time.Duration 
{ + tb.Helper() + if samples <= 0 { + samples = 1 + } + values := make([]time.Duration, 0, samples) + for i := 0; i < samples; i++ { + values = append(values, measureLargeRepoKnownToggleOnce(tb, repoRoot, packageCount)) + } + sort.Slice(values, func(i, j int) bool { return values[i] < values[j] }) + return values[len(values)/2] +} + func formatPercentImprovement(normal time.Duration, incremental time.Duration) string { if normal <= 0 { return "0.0%" @@ -488,7 +1274,7 @@ func formatBenchmarkDuration(d time.Duration) string { case d >= time.Millisecond: return fmt.Sprintf("%.2fms", float64(d)/float64(time.Millisecond)) case d >= time.Microsecond: - return fmt.Sprintf("%.2fµs", float64(d)/float64(time.Microsecond)) + return fmt.Sprintf("%.2fus", float64(d)/float64(time.Microsecond)) default: return d.String() } @@ -674,8 +1460,8 @@ func renderASCIITable(rows [][]string) string { widths := make([]int, len(rows[0])) for _, row := range rows { for i, cell := range row { - if len(cell) > widths[i] { - widths[i] = len(cell) + if width := utf8.RuneCountInString(cell); width > widths[i] { + widths[i] = width } } } @@ -693,7 +1479,7 @@ func renderASCIITable(rows [][]string) string { for i, cell := range row { b.WriteByte(' ') b.WriteString(cell) - b.WriteString(strings.Repeat(" ", widths[i]-len(cell)+1)) + b.WriteString(strings.Repeat(" ", widths[i]-utf8.RuneCountInString(cell)+1)) b.WriteByte('|') } b.WriteByte('\n') diff --git a/internal/wire/incremental_fingerprint.go b/internal/wire/incremental_fingerprint.go index 42c2317..be39982 100644 --- a/internal/wire/incremental_fingerprint.go +++ b/internal/wire/incremental_fingerprint.go @@ -31,7 +31,7 @@ import ( "golang.org/x/tools/go/packages" ) -const incrementalFingerprintVersion = "wire-incremental-v1" +const incrementalFingerprintVersion = "wire-incremental-v3" type packageFingerprint struct { Version string @@ -39,6 +39,8 @@ type packageFingerprint struct { Tags string PkgPath string Files []cacheFile + Dirs []cacheFile 
+ ContentHash string ShapeHash string LocalImports []string } @@ -159,10 +161,12 @@ func buildIncrementalManifestSnapshotFromPackages(wd string, tags string, pkgs [ sort.Strings(localImports) snapshot.fingerprints[pkg.PkgPath] = &packageFingerprint{ Version: incrementalFingerprintVersion, - WD: filepath.Clean(wd), + WD: packageCacheScope(wd), Tags: tags, PkgPath: pkg.PkgPath, Files: metaFiles, + Dirs: mustBuildPackageDirCacheFiles(files), + ContentHash: mustHashPackageFiles(files), ShapeHash: shapeHash, LocalImports: localImports, } @@ -183,11 +187,52 @@ func packageFingerprintFiles(pkg *packages.Package) []string { return append([]string(nil), pkg.GoFiles...) } +func packageFingerprintDirs(files []string) []string { + if len(files) == 0 { + return nil + } + dirs := make([]string, 0, len(files)) + seen := make(map[string]struct{}, len(files)) + for _, name := range files { + dir := filepath.Clean(filepath.Dir(name)) + if _, ok := seen[dir]; ok { + continue + } + seen[dir] = struct{}{} + dirs = append(dirs, dir) + } + sort.Strings(dirs) + return dirs +} + +func mustBuildPackageDirCacheFiles(files []string) []cacheFile { + dirs := packageFingerprintDirs(files) + if len(dirs) == 0 { + return nil + } + meta, err := buildCacheFiles(dirs) + if err != nil { + return nil + } + return meta +} + +func mustHashPackageFiles(files []string) string { + if len(files) == 0 { + return "" + } + hash, err := hashFiles(files) + if err != nil { + return "" + } + return hash +} + func incrementalFingerprintEquivalent(a, b *packageFingerprint) bool { if a == nil || b == nil { return false } - if a.ShapeHash != b.ShapeHash || a.PkgPath != b.PkgPath || a.Tags != b.Tags || filepath.Clean(a.WD) != filepath.Clean(b.WD) { + if a.ShapeHash != b.ShapeHash || a.PkgPath != b.PkgPath || a.Tags != b.Tags || a.WD != b.WD { return false } if len(a.LocalImports) != len(b.LocalImports) { @@ -205,7 +250,7 @@ func incrementalFingerprintMetaMatches(prev *packageFingerprint, wd string, tags if prev == nil 
|| prev.Version != incrementalFingerprintVersion { return false } - if filepath.Clean(prev.WD) != filepath.Clean(wd) || prev.Tags != tags || prev.PkgPath != pkgPath { + if prev.WD != packageCacheScope(wd) || prev.Tags != tags || prev.PkgPath != pkgPath { return false } if len(prev.Files) != len(files) { @@ -234,10 +279,12 @@ func buildPackageFingerprint(wd string, tags string, pkg *packages.Package, file sort.Strings(localImports) return &packageFingerprint{ Version: incrementalFingerprintVersion, - WD: filepath.Clean(wd), + WD: packageCacheScope(wd), Tags: tags, PkgPath: pkg.PkgPath, Files: append([]cacheFile(nil), files...), + Dirs: mustBuildPackageDirCacheFiles(packageFingerprintFiles(pkg)), + ContentHash: mustHashPackageFiles(packageFingerprintFiles(pkg)), ShapeHash: shapeHash, LocalImports: localImports, }, nil @@ -278,6 +325,7 @@ func writeSyntaxShapeHash(buf *bytes.Buffer, fset *token.FileSet, file *ast.File if file == nil || buf == nil || fset == nil { return } + usedImports := usedImportNamesInShape(file) if file.Name != nil { buf.WriteString("package ") buf.WriteString(file.Name.Name) @@ -294,6 +342,10 @@ func writeSyntaxShapeHash(buf *bytes.Buffer, fset *token.FileSet, file *ast.File buf.WriteByte(' ') writeNodeHash(buf, fset, decl.Type) buf.WriteByte('\n') + case *ast.GenDecl: + if writeGenDeclShapeHash(buf, fset, decl, usedImports) { + buf.WriteByte('\n') + } default: writeNodeHash(buf, fset, decl) buf.WriteByte('\n') @@ -301,6 +353,111 @@ func writeSyntaxShapeHash(buf *bytes.Buffer, fset *token.FileSet, file *ast.File } } +func writeGenDeclShapeHash(buf *bytes.Buffer, fset *token.FileSet, decl *ast.GenDecl, usedImports map[string]struct{}) bool { + if buf == nil || fset == nil || decl == nil { + return false + } + var specBuf bytes.Buffer + wrote := false + for _, spec := range decl.Specs { + switch spec := spec.(type) { + case *ast.ImportSpec: + name := importName(spec) + if name == "_" || name == "." 
{ + if spec.Name != nil { + specBuf.WriteString(spec.Name.Name) + } + specBuf.WriteByte(' ') + writeNodeHash(&specBuf, fset, spec.Path) + specBuf.WriteByte('\n') + wrote = true + break + } + if _, ok := usedImports[name]; !ok { + continue + } + if spec.Name != nil { + specBuf.WriteString(spec.Name.Name) + } + specBuf.WriteByte(' ') + writeNodeHash(&specBuf, fset, spec.Path) + case *ast.TypeSpec: + if spec.Name != nil { + specBuf.WriteString(spec.Name.Name) + } + specBuf.WriteByte(' ') + writeNodeHash(&specBuf, fset, spec.Type) + case *ast.ValueSpec: + for _, name := range spec.Names { + if name != nil { + specBuf.WriteString(name.Name) + } + specBuf.WriteByte(' ') + } + if spec.Type != nil { + writeNodeHash(&specBuf, fset, spec.Type) + } + default: + writeNodeHash(&specBuf, fset, spec) + } + specBuf.WriteByte('\n') + wrote = true + } + if !wrote { + return false + } + buf.WriteString(decl.Tok.String()) + buf.WriteByte(' ') + buf.Write(specBuf.Bytes()) + return true +} + +func usedImportNamesInShape(file *ast.File) map[string]struct{} { + used := make(map[string]struct{}) + if file == nil { + return used + } + record := func(node ast.Node) { + ast.Inspect(node, func(n ast.Node) bool { + sel, ok := n.(*ast.SelectorExpr) + if !ok { + return true + } + ident, ok := sel.X.(*ast.Ident) + if !ok || ident.Name == "" { + return true + } + used[ident.Name] = struct{}{} + return true + }) + } + for _, decl := range file.Decls { + switch decl := decl.(type) { + case *ast.FuncDecl: + if decl.Recv != nil { + record(decl.Recv) + } + if decl.Type != nil { + record(decl.Type) + } + case *ast.GenDecl: + for _, spec := range decl.Specs { + switch spec := spec.(type) { + case *ast.TypeSpec: + if spec.Type != nil { + record(spec.Type) + } + case *ast.ValueSpec: + if spec.Type != nil { + record(spec.Type) + } + } + } + } + } + return used +} + func writeNodeHash(buf *bytes.Buffer, fset *token.FileSet, node interface{}) { if buf == nil || fset == nil || node == nil { return @@ -324,7 
+481,7 @@ func incrementalFingerprintKey(wd string, tags string, pkgPath string) string { h := sha256.New() h.Write([]byte(incrementalFingerprintVersion)) h.Write([]byte{0}) - h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte(packageCacheScope(wd))) h.Write([]byte{0}) h.Write([]byte(tags)) h.Write([]byte{0}) diff --git a/internal/wire/incremental_fingerprint_test.go b/internal/wire/incremental_fingerprint_test.go index afe81de..920d08e 100644 --- a/internal/wire/incremental_fingerprint_test.go +++ b/internal/wire/incremental_fingerprint_test.go @@ -41,6 +41,44 @@ func TestPackageShapeHashIgnoresFunctionBodies(t *testing.T) { } } +func TestPackageShapeHashIgnoresConstValueChanges(t *testing.T) { + dir := t.TempDir() + file := writeTempFile(t, dir, "pkg.go", "package p\n\nconst SQLText = \"a\"\n") + hash1, err := packageShapeHash([]string{file}) + if err != nil { + t.Fatalf("packageShapeHash first failed: %v", err) + } + if err := os.WriteFile(file, []byte("package p\n\nconst SQLText = \"b\"\n"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + hash2, err := packageShapeHash([]string{file}) + if err != nil { + t.Fatalf("packageShapeHash second failed: %v", err) + } + if hash1 != hash2 { + t.Fatalf("const-value change should not affect shape hash: %q vs %q", hash1, hash2) + } +} + +func TestPackageShapeHashIgnoresImplementationOnlyImportChanges(t *testing.T) { + dir := t.TempDir() + file := writeTempFile(t, dir, "pkg.go", "package p\n\nfunc Hello() string { return \"a\" }\n") + hash1, err := packageShapeHash([]string{file}) + if err != nil { + t.Fatalf("packageShapeHash first failed: %v", err) + } + if err := os.WriteFile(file, []byte("package p\n\nimport \"fmt\"\n\nfunc Hello() string { return fmt.Sprint(\"a\") }\n"), 0644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + hash2, err := packageShapeHash([]string{file}) + if err != nil { + t.Fatalf("packageShapeHash second failed: %v", err) + } + if hash1 != hash2 { + 
t.Fatalf("implementation-only import change should not affect shape hash: %q vs %q", hash1, hash2) + } +} + func TestIncrementalFingerprintRoundTrip(t *testing.T) { fp := &packageFingerprint{ Version: incrementalFingerprintVersion, diff --git a/internal/wire/incremental_graph.go b/internal/wire/incremental_graph.go index 66cf28d..37b3d0f 100644 --- a/internal/wire/incremental_graph.go +++ b/internal/wire/incremental_graph.go @@ -58,7 +58,7 @@ func buildIncrementalGraph(wd string, tags string, pkgs []*packages.Package) *in moduleRoot := findModuleRoot(wd) graph := &incrementalGraph{ Version: incrementalGraphVersion, - WD: filepath.Clean(wd), + WD: packageCacheScope(wd), Tags: tags, Roots: make([]string, 0, len(pkgs)), LocalReverse: make(map[string][]string), @@ -126,7 +126,7 @@ func incrementalGraphKey(wd string, tags string, roots []string) string { h := sha256.New() h.Write([]byte(incrementalGraphVersion)) h.Write([]byte{0}) - h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte(packageCacheScope(wd))) h.Write([]byte{0}) h.Write([]byte(tags)) h.Write([]byte{0}) diff --git a/internal/wire/incremental_manifest.go b/internal/wire/incremental_manifest.go index 4a55c19..8fab10e 100644 --- a/internal/wire/incremental_manifest.go +++ b/internal/wire/incremental_manifest.go @@ -20,6 +20,7 @@ import ( "crypto/sha256" "encoding/binary" "fmt" + "go/token" "os" "path/filepath" "sort" @@ -28,7 +29,7 @@ import ( "golang.org/x/tools/go/packages" ) -const incrementalManifestVersion = "wire-incremental-manifest-v1" +const incrementalManifestVersion = "wire-incremental-manifest-v3" type incrementalManifest struct { Version string @@ -61,9 +62,19 @@ type incrementalPreloadState struct { manifest *incrementalManifest valid bool currentLocal []packageFingerprint + touched []string reason string } +type incrementalPreloadValidation struct { + valid bool + currentLocal []packageFingerprint + touched []string + reason string +} + +const touchedValidationVersion = 
"wire-touched-validation-v1" + func readPreloadIncrementalManifestResults(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) ([]GenerateResult, bool) { state, ok := prepareIncrementalPreloadState(ctx, wd, env, patterns, opts) return readPreloadIncrementalManifestResultsFromState(ctx, wd, env, patterns, opts, state, ok) @@ -75,15 +86,36 @@ func readPreloadIncrementalManifestResultsFromState(ctx context.Context, wd stri return nil, false } if state.valid { + validateStart := timeNow() + if len(state.touched) > 0 { + debugf(ctx, "incremental.preload_manifest touched=%s", strings.Join(state.touched, ",")) + } + if err := validateIncrementalPreloadTouchedPackages(ctx, wd, env, opts, state.currentLocal, state.touched); err != nil { + logTiming(ctx, "incremental.preload_manifest.validate_touched", validateStart) + if shouldBypassIncrementalManifestAfterFastPathError(err) { + invalidateIncrementalPreloadState(state) + } + debugf(ctx, "incremental.preload_manifest miss reason=touched_validation") + return nil, false + } + logTiming(ctx, "incremental.preload_manifest.validate_touched", validateStart) + outputsStart := timeNow() results, ok := incrementalManifestOutputs(state.manifest) + logTiming(ctx, "incremental.preload_manifest.outputs", outputsStart) if !ok { debugf(ctx, "incremental.preload_manifest miss reason=outputs") return nil, false } + if manifestNeedsLocalRefresh(state.manifest.LocalPackages, state.currentLocal) { + refreshed := *state.manifest + refreshed.LocalPackages = append([]packageFingerprint(nil), state.currentLocal...) 
+ writeIncrementalManifestFile(state.selectorKey, &refreshed) + writeIncrementalManifestFile(incrementalManifestStateKey(state.selectorKey, refreshed.LocalPackages), &refreshed) + } debugf(ctx, "incremental.preload_manifest hit outputs=%d", len(results)) return results, true } else if archived := readStateIncrementalManifest(state.selectorKey, state.currentLocal); archived != nil { - if ok, _, _ := incrementalManifestPreloadValid(ctx, archived, wd, env, patterns, opts); ok { + if validation := incrementalManifestPreloadValid(ctx, archived, wd, env, patterns, opts); validation.valid { results, ok := incrementalManifestOutputs(archived) if !ok { debugf(ctx, "incremental.preload_manifest miss reason=state_outputs") @@ -107,13 +139,14 @@ func prepareIncrementalPreloadState(ctx context.Context, wd string, env []string if !ok { return nil, false } - valid, currentLocal, reason := incrementalManifestPreloadValid(ctx, manifest, wd, env, patterns, opts) + validation := incrementalManifestPreloadValid(ctx, manifest, wd, env, patterns, opts) return &incrementalPreloadState{ selectorKey: selectorKey, manifest: manifest, - valid: valid, - currentLocal: currentLocal, - reason: reason, + valid: validation.valid, + currentLocal: validation.currentLocal, + touched: validation.touched, + reason: validation.reason, }, true } @@ -150,6 +183,7 @@ func writeIncrementalManifestWithOptions(wd string, env []string, patterns []str if snapshot == nil || len(generated) == 0 { return } + scope := runCacheScope(wd, patterns) externalPkgs := buildExternalPackageExports(wd, pkgs) var externalFiles []cacheFile if includeExternalFiles { @@ -161,12 +195,12 @@ func writeIncrementalManifestWithOptions(wd string, env []string, patterns []str } manifest := &incrementalManifest{ Version: incrementalManifestVersion, - WD: filepath.Clean(wd), + WD: scope, Tags: opts.Tags, Prefix: opts.PrefixOutputFile, HeaderHash: headerHash(opts.Header), EnvHash: envHash(env), - Patterns: sortedStrings(patterns), + 
Patterns: normalizePatternsForScope(wd, packageCacheScope(wd), patterns), LocalPackages: snapshotPackageFingerprints(snapshot), ExternalPkgs: externalPkgs, ExternalFiles: externalFiles, @@ -197,7 +231,7 @@ func incrementalManifestSelectorKey(wd string, env []string, patterns []string, h := sha256.New() h.Write([]byte(incrementalManifestVersion)) h.Write([]byte{0}) - h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte(runCacheScope(wd, patterns))) h.Write([]byte{0}) h.Write([]byte(envHash(env))) h.Write([]byte{0}) @@ -207,7 +241,7 @@ func incrementalManifestSelectorKey(wd string, env []string, patterns []string, h.Write([]byte{0}) h.Write([]byte(headerHash(opts.Header))) h.Write([]byte{0}) - for _, p := range sortedStrings(patterns) { + for _, p := range normalizePatternsForScope(wd, packageCacheScope(wd), patterns) { h.Write([]byte(p)) h.Write([]byte{0}) } @@ -276,16 +310,17 @@ func incrementalManifestValid(manifest *incrementalManifest, wd string, env []st if manifest == nil || manifest.Version != incrementalManifestVersion { return false } - if filepath.Clean(manifest.WD) != filepath.Clean(wd) || manifest.Tags != opts.Tags || manifest.Prefix != opts.PrefixOutputFile { + if manifest.WD != runCacheScope(wd, patterns) || manifest.Tags != opts.Tags || manifest.Prefix != opts.PrefixOutputFile { return false } if manifest.HeaderHash != headerHash(opts.Header) || manifest.EnvHash != envHash(env) { return false } - if len(manifest.Patterns) != len(patterns) { + normalizedPatterns := normalizePatternsForScope(wd, packageCacheScope(wd), patterns) + if len(manifest.Patterns) != len(normalizedPatterns) { return false } - for i, p := range sortedStrings(patterns) { + for i, p := range normalizedPatterns { if manifest.Patterns[i] != p { return false } @@ -313,58 +348,93 @@ func incrementalManifestValid(manifest *incrementalManifest, wd string, env []st return len(manifest.Outputs) > 0 } -func incrementalManifestPreloadValid(ctx context.Context, manifest *incrementalManifest, 
wd string, env []string, patterns []string, opts *GenerateOptions) (bool, []packageFingerprint, string) { +func incrementalManifestPreloadValid(ctx context.Context, manifest *incrementalManifest, wd string, env []string, patterns []string, opts *GenerateOptions) incrementalPreloadValidation { if manifest == nil || manifest.Version != incrementalManifestVersion { - return false, nil, "version" + return incrementalPreloadValidation{reason: "version"} } - if filepath.Clean(manifest.WD) != filepath.Clean(wd) || manifest.Tags != opts.Tags || manifest.Prefix != opts.PrefixOutputFile { - return false, nil, "config" + if manifest.WD != runCacheScope(wd, patterns) || manifest.Tags != opts.Tags || manifest.Prefix != opts.PrefixOutputFile { + return incrementalPreloadValidation{reason: "config"} } if manifest.HeaderHash != headerHash(opts.Header) || manifest.EnvHash != envHash(env) { - return false, nil, "env" + return incrementalPreloadValidation{reason: "env"} } - if len(manifest.Patterns) != len(patterns) { - return false, nil, "patterns.length" + normalizedPatterns := normalizePatternsForScope(wd, packageCacheScope(wd), patterns) + if len(manifest.Patterns) != len(normalizedPatterns) { + return incrementalPreloadValidation{reason: "patterns.length"} } - for i, p := range sortedStrings(patterns) { + for i, p := range normalizedPatterns { if manifest.Patterns[i] != p { - return false, nil, "patterns.value" + return incrementalPreloadValidation{reason: "patterns.value"} } } if len(manifest.ExtraFiles) > 0 { + extraStart := timeNow() current, err := buildCacheFilesFromMeta(manifest.ExtraFiles) + logTiming(ctx, "incremental.preload_manifest.validate_extra_files", extraStart) if err != nil || len(current) != len(manifest.ExtraFiles) { - return false, nil, "extra_files" + return incrementalPreloadValidation{reason: "extra_files"} } for i := range current { if current[i] != manifest.ExtraFiles[i] { - return false, nil, "extra_files.diff" + return 
incrementalPreloadValidation{reason: "extra_files.diff"} } } } - currentLocal, ok, reason := incrementalManifestCurrentLocalPackages(ctx, manifest.LocalPackages) - if !ok { - return false, currentLocal, "local_packages." + reason + localStart := timeNow() + packagesState := incrementalManifestCurrentLocalPackages(ctx, manifest.LocalPackages) + logTiming(ctx, "incremental.preload_manifest.validate_local_packages", localStart) + if !packagesState.valid { + return incrementalPreloadValidation{ + currentLocal: packagesState.currentLocal, + touched: packagesState.touched, + reason: "local_packages." + packagesState.reason, + } } if len(manifest.ExternalFiles) > 0 { + externalStart := timeNow() current, err := buildCacheFilesFromMeta(manifest.ExternalFiles) + logTiming(ctx, "incremental.preload_manifest.validate_external_files", externalStart) if err != nil || len(current) != len(manifest.ExternalFiles) { - return false, currentLocal, "external_files" + return incrementalPreloadValidation{ + currentLocal: packagesState.currentLocal, + touched: packagesState.touched, + reason: "external_files", + } } for i := range current { if current[i] != manifest.ExternalFiles[i] { - return false, currentLocal, "external_files.diff" + return incrementalPreloadValidation{ + currentLocal: packagesState.currentLocal, + touched: packagesState.touched, + reason: "external_files.diff", + } } } } if len(manifest.Outputs) == 0 { - return false, currentLocal, "outputs" + return incrementalPreloadValidation{ + currentLocal: packagesState.currentLocal, + touched: packagesState.touched, + reason: "outputs", + } + } + return incrementalPreloadValidation{ + valid: true, + currentLocal: packagesState.currentLocal, + touched: packagesState.touched, } - return true, currentLocal, "" } -func incrementalManifestCurrentLocalPackages(ctx context.Context, local []packageFingerprint) ([]packageFingerprint, bool, string) { +type incrementalLocalPackagesState struct { + valid bool + currentLocal 
[]packageFingerprint + touched []string + reason string +} + +func incrementalManifestCurrentLocalPackages(ctx context.Context, local []packageFingerprint) incrementalLocalPackagesState { currentState := make([]packageFingerprint, 0, len(local)) + touched := make([]string, 0, len(local)) var firstReason string for _, fp := range local { if len(fp.Files) == 0 { @@ -400,42 +470,158 @@ func incrementalManifestCurrentLocalPackages(ctx context.Context, local []packag } } if !sameMeta { - shapeHash, err := packageShapeHash(storedFiles) + if diffs := describeCacheFileDiffs(fp.Files, currentMeta); len(diffs) > 0 { + debugf(ctx, "incremental.preload_manifest local_pkg=%s meta_diff=%s", fp.PkgPath, strings.Join(diffs, "; ")) + } + contentHash, err := hashFiles(storedFiles) if err != nil { - debugf(ctx, "incremental.preload_manifest local_pkg=%s shape_error=%v", fp.PkgPath, err) + debugf(ctx, "incremental.preload_manifest local_pkg=%s content_error=%v", fp.PkgPath, err) if firstReason == "" { - firstReason = fp.PkgPath + ".shape_error" + firstReason = fp.PkgPath + ".content_error" } continue } - currentFP.ShapeHash = shapeHash - if shapeHash != fp.ShapeHash { - debugf(ctx, "incremental.preload_manifest local_pkg=%s stored_shape=%s current_shape=%s files=%s", fp.PkgPath, fp.ShapeHash, shapeHash, strings.Join(storedFiles, ",")) - if firstReason == "" { - firstReason = fp.PkgPath + ".shape_mismatch" + currentFP.ContentHash = contentHash + if contentHash != fp.ContentHash { + debugf(ctx, "incremental.preload_manifest local_pkg=%s stored_content=%s current_content=%s hash_files=%s", fp.PkgPath, fp.ContentHash, contentHash, strings.Join(storedFiles, ",")) + shapeHash, err := packageShapeHash(storedFiles) + if err != nil { + debugf(ctx, "incremental.preload_manifest local_pkg=%s shape_error=%v", fp.PkgPath, err) + if firstReason == "" { + firstReason = fp.PkgPath + ".shape_error" + } + continue + } + currentFP.ShapeHash = shapeHash + if shapeHash != fp.ShapeHash { + debugf(ctx, 
"incremental.preload_manifest local_pkg=%s stored_shape=%s current_shape=%s files=%s", fp.PkgPath, fp.ShapeHash, shapeHash, strings.Join(storedFiles, ",")) + if firstReason == "" { + firstReason = fp.PkgPath + ".shape_mismatch" + } + } else { + debugf(ctx, "incremental.preload_manifest local_pkg=%s content_changed_shape_unchanged", fp.PkgPath) + touched = append(touched, fp.PkgPath) } - } else if firstReason == "" { - firstReason = fp.PkgPath + ".meta_changed" } } - if changed, err := packageDirectoryIntroducedRelevantFiles(fp.Files); err != nil { - debugf(ctx, "incremental.preload_manifest local_pkg=%s dir_scan_error=%v", fp.PkgPath, err) + currentDirs, dirsChanged, err := packageDirectoryMetaChanged(fp, storedFiles) + if err != nil { + debugf(ctx, "incremental.preload_manifest local_pkg=%s dir_meta_error=%v", fp.PkgPath, err) if firstReason == "" { - firstReason = fp.PkgPath + ".dir_scan_error" + firstReason = fp.PkgPath + ".dir_meta_error" } continue - } else if changed { - debugf(ctx, "incremental.preload_manifest local_pkg=%s introduced_relevant_files=true", fp.PkgPath) - if firstReason == "" { - firstReason = fp.PkgPath + ".introduced_relevant_files" + } + currentFP.Dirs = currentDirs + if dirsChanged { + if changed, err := packageDirectoryIntroducedRelevantFiles(fp.Files); err != nil { + debugf(ctx, "incremental.preload_manifest local_pkg=%s dir_scan_error=%v", fp.PkgPath, err) + if firstReason == "" { + firstReason = fp.PkgPath + ".dir_scan_error" + } + continue + } else if changed { + debugf(ctx, "incremental.preload_manifest local_pkg=%s introduced_relevant_files=true", fp.PkgPath) + if firstReason == "" { + firstReason = fp.PkgPath + ".introduced_relevant_files" + } } } currentState = append(currentState, currentFP) } if firstReason != "" { - return currentState, false, firstReason + return incrementalLocalPackagesState{ + currentLocal: currentState, + touched: touched, + reason: firstReason, + } + } + sort.Strings(touched) + return 
incrementalLocalPackagesState{ + valid: true, + currentLocal: currentState, + touched: touched, + } +} + +func validateIncrementalPreloadTouchedPackages(ctx context.Context, wd string, env []string, opts *GenerateOptions, local []packageFingerprint, touched []string) error { + if len(touched) == 0 { + return nil + } + cacheKey := touchedValidationKey(wd, env, opts, local, touched) + if cacheKey != "" { + cacheHitStart := timeNow() + if _, ok := readCache(cacheKey); ok { + logTiming(ctx, "incremental.preload_manifest.validate_touched_cache_hit", cacheHitStart) + return nil + } + } + cfg := &packages.Config{ + Context: ctx, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedExportsFile | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesSizes, + Dir: wd, + Env: env, + BuildFlags: []string{"-tags=wireinject"}, + Fset: token.NewFileSet(), + } + if len(opts.Tags) > 0 { + cfg.BuildFlags[0] += " " + opts.Tags + } + loadStart := timeNow() + pkgs, err := packages.Load(cfg, touched...) 
+ logTiming(ctx, "incremental.preload_manifest.validate_touched_load", loadStart) + if err != nil { + return err + } + errorsStart := timeNow() + byPath := make(map[string]*packages.Package, len(pkgs)) + for _, pkg := range pkgs { + if pkg != nil { + byPath[pkg.PkgPath] = pkg + } + } + for _, path := range touched { + if pkg := byPath[path]; pkg != nil && len(pkg.Errors) > 0 { + logTiming(ctx, "incremental.preload_manifest.validate_touched_errors", errorsStart) + return formatLocalTypeCheckError(wd, pkg.PkgPath, pkg.Errors) + } } - return currentState, true, "" + logTiming(ctx, "incremental.preload_manifest.validate_touched_errors", errorsStart) + if cacheKey != "" { + cacheWriteStart := timeNow() + writeCache(cacheKey, []byte("ok")) + logTiming(ctx, "incremental.preload_manifest.validate_touched_cache_write", cacheWriteStart) + } + return nil +} + +func touchedValidationKey(wd string, env []string, opts *GenerateOptions, local []packageFingerprint, touched []string) string { + if len(touched) == 0 { + return "" + } + byPath := fingerprintsFromSlice(local) + h := sha256.New() + h.Write([]byte(touchedValidationVersion)) + h.Write([]byte{0}) + h.Write([]byte(packageCacheScope(wd))) + h.Write([]byte{0}) + h.Write([]byte(envHash(env))) + h.Write([]byte{0}) + if opts != nil { + h.Write([]byte(opts.Tags)) + } + h.Write([]byte{0}) + for _, pkgPath := range touched { + fp := byPath[pkgPath] + if fp == nil || fp.ContentHash == "" { + return "" + } + h.Write([]byte(pkgPath)) + h.Write([]byte{0}) + h.Write([]byte(fp.ContentHash)) + h.Write([]byte{0}) + } + return fmt.Sprintf("%x", h.Sum(nil)) } func incrementalManifestOutputs(manifest *incrementalManifest) ([]GenerateResult, bool) { @@ -500,6 +686,89 @@ func filesFromMeta(files []cacheFile) []string { return out } +func describeCacheFileDiffs(stored []cacheFile, current []cacheFile) []string { + if len(stored) == 0 && len(current) == 0 { + return nil + } + storedByPath := make(map[string]cacheFile, len(stored)) + 
currentByPath := make(map[string]cacheFile, len(current)) + for _, file := range stored { + storedByPath[filepath.Clean(file.Path)] = file + } + for _, file := range current { + currentByPath[filepath.Clean(file.Path)] = file + } + paths := make([]string, 0, len(storedByPath)+len(currentByPath)) + seen := make(map[string]struct{}, len(storedByPath)+len(currentByPath)) + for path := range storedByPath { + if _, ok := seen[path]; ok { + continue + } + seen[path] = struct{}{} + paths = append(paths, path) + } + for path := range currentByPath { + if _, ok := seen[path]; ok { + continue + } + seen[path] = struct{}{} + paths = append(paths, path) + } + sort.Strings(paths) + diffs := make([]string, 0, len(paths)) + for _, path := range paths { + storedFile, storedOK := storedByPath[path] + currentFile, currentOK := currentByPath[path] + switch { + case !storedOK: + diffs = append(diffs, fmt.Sprintf("%s added size=%d mtime=%d", path, currentFile.Size, currentFile.ModTime)) + case !currentOK: + diffs = append(diffs, fmt.Sprintf("%s removed size=%d mtime=%d", path, storedFile.Size, storedFile.ModTime)) + case storedFile != currentFile: + diffs = append(diffs, fmt.Sprintf("%s size:%d->%d mtime:%d->%d", path, storedFile.Size, currentFile.Size, storedFile.ModTime, currentFile.ModTime)) + } + } + return diffs +} + +func manifestNeedsLocalRefresh(stored []packageFingerprint, current []packageFingerprint) bool { + if len(stored) != len(current) { + return false + } + for i := range stored { + if stored[i].PkgPath != current[i].PkgPath { + return false + } + if stored[i].ContentHash == "" && current[i].ContentHash != "" { + return true + } + if len(stored[i].Dirs) == 0 && len(current[i].Dirs) > 0 { + return true + } + } + return false +} + +func packageDirectoryMetaChanged(fp packageFingerprint, storedFiles []string) ([]cacheFile, bool, error) { + dirs := packageFingerprintDirs(storedFiles) + if len(dirs) == 0 { + return nil, false, nil + } + current, err := buildCacheFiles(dirs) 
+ if err != nil { + return nil, false, err + } + if len(fp.Dirs) != len(current) { + return current, true, nil + } + for i := range current { + if current[i] != fp.Dirs[i] { + return current, true, nil + } + } + return current, false, nil +} + func packageDirectoryIntroducedRelevantFiles(files []cacheFile) (bool, error) { dirs := make(map[string]struct{}) old := make(map[string]struct{}, len(files)) diff --git a/internal/wire/incremental_session.go b/internal/wire/incremental_session.go index 72b051b..2fdaa2b 100644 --- a/internal/wire/incremental_session.go +++ b/internal/wire/incremental_session.go @@ -46,7 +46,7 @@ func clearIncrementalSessions() { func sessionKey(wd string, env []string, tags string) string { var b strings.Builder - b.WriteString(filepath.Clean(wd)) + b.WriteString(packageCacheScope(wd)) b.WriteByte('\n') b.WriteString(tags) b.WriteByte('\n') diff --git a/internal/wire/incremental_summary.go b/internal/wire/incremental_summary.go index 2930b37..934f637 100644 --- a/internal/wire/incremental_summary.go +++ b/internal/wire/incremental_summary.go @@ -99,7 +99,7 @@ func incrementalSummaryKey(wd string, tags string, pkgPath string) string { h := sha256.New() h.Write([]byte(incrementalSummaryVersion)) h.Write([]byte{0}) - h.Write([]byte(filepath.Clean(wd))) + h.Write([]byte(packageCacheScope(wd))) h.Write([]byte{0}) h.Write([]byte(tags)) h.Write([]byte{0}) diff --git a/internal/wire/loader_test.go b/internal/wire/loader_test.go index d74c0f8..37e27d9 100644 --- a/internal/wire/loader_test.go +++ b/internal/wire/loader_test.go @@ -220,7 +220,7 @@ func TestLoadAndGenerateModuleIncrementalMatches(t *testing.T) { } } -func TestGenerateIncrementalBodyOnlyChangeFallsBackToLoadAndReusesOutput(t *testing.T) { +func TestGenerateIncrementalBodyOnlyChangeUsesPreloadManifestAndReusesOutput(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() t.Cleanup(func() { restoreCacheHooks(state) }) @@ -340,14 +340,253 @@ func 
TestGenerateIncrementalBodyOnlyChangeFallsBackToLoadAndReusesOutput(t *test if len(second) != 1 || len(second[0].Errs) > 0 { t.Fatalf("unexpected second Generate result: %+v", second) } - if !containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected second Generate to re-load packages after body-only change, labels=%v", secondLabels) + if containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected second Generate to reuse preload manifest after body-only change, labels=%v", secondLabels) } if string(first[0].Content) != string(second[0].Content) { t.Fatal("expected body-only change to reuse identical generated output") } } +func TestGenerateIncrementalTouchedValidationCacheReusesSuccessfulValidation(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + writeBodyVariant := func(message string) { + t.Helper() + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return \"" + message + "\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + } + writeBodyVariant("a") + + writeFile(t, 
filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + + writeBodyVariant("b") + + var secondLabels []string + secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { + secondLabels = append(secondLabels, label) + }) + second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("second Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected second Generate result: %+v", second) + } + if containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected first body-only variant change to avoid generate.load, labels=%v", secondLabels) + } + if containsLabel(secondLabels, "incremental.preload_manifest.validate_touched_cache_hit") { + t.Fatalf("did not expect first body-only variant change to hit touched validation cache, labels=%v", secondLabels) + } + + writeBodyVariant("a") + third, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("third Generate returned errors: %v", errs) + } + if len(third) != 1 || len(third[0].Errs) > 0 { + t.Fatalf("unexpected third Generate result: %+v", third) + } + + writeBodyVariant("b") + + var fourthLabels []string + fourthCtx := WithTiming(ctx, func(label string, _ time.Duration) { + fourthLabels = append(fourthLabels, label) + }) + fourth, errs := Generate(fourthCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + 
t.Fatalf("fourth Generate returned errors: %v", errs) + } + if len(fourth) != 1 || len(fourth[0].Errs) > 0 { + t.Fatalf("unexpected fourth Generate result: %+v", fourth) + } + if containsLabel(fourthLabels, "generate.load") { + t.Fatalf("expected repeated body-only variant change to avoid generate.load, labels=%v", fourthLabels) + } + if !containsLabel(fourthLabels, "incremental.preload_manifest.validate_touched_cache_hit") { + t.Fatalf("expected repeated body-only variant change to hit touched validation cache, labels=%v", fourthLabels) + } + if string(first[0].Content) != string(fourth[0].Content) { + t.Fatal("expected repeated body-only variant change to reuse identical generated output") + } +} + +func TestGenerateIncrementalConstValueChangeUsesPreloadManifestAndReusesOutput(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + writeFile(t, 
filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + env := append(os.Environ(), "GOWORK=off") + ctx := WithIncremental(context.Background(), true) + + first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("first Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected first Generate result: %+v", first) + } + + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"blue\"", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + + var secondLabels []string + secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { + secondLabels = append(secondLabels, label) + }) + second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("second Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected second Generate result: %+v", second) + } + if containsLabel(secondLabels, "generate.load") { + t.Fatalf("expected const-value change to reuse preload manifest, labels=%v", secondLabels) + } + if string(first[0].Content) != string(second[0].Content) { + t.Fatal("expected const-value change to reuse identical generated output") + } +} + func TestGenerateIncrementalBodyOnlyInvalidChangeDoesNotReusePreloadManifest(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() @@ -453,6 +692,448 @@ func TestGenerateIncrementalBodyOnlyInvalidChangeDoesNotReusePreloadManifest(t * } } +func TestGenerateIncrementalScenarioMatrix(t *testing.T) { + t.Parallel() + + type scenarioExpectation struct { + mode string + wantErr bool + 
wantSameOutput bool + } + + scenarios := []struct { + name string + apply func(t *testing.T, fx incrementalScenarioFixture) + want scenarioExpectation + }{ + { + name: "comment_only_change_reuses_preload", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "// SQLText controls SQL highlighting in log output.", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "preload", wantSameOutput: true}, + }, + { + name: "whitespace_only_change_reuses_preload", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "", + "func New(msg string) *Foo {", + "", + "\treturn &Foo{Message: helper(msg)}", + "", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "preload", wantSameOutput: true}, + }, + { + name: "function_body_change_reuses_preload", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string {", + "\treturn helper(SQLText)", + "}", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "preload", 
wantSameOutput: true}, + }, + { + name: "method_body_change_uses_local_fastpath", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func (f Foo) Summary() string {", + "\treturn helper(f.Message)", + "}", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: msg}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: true}, + }, + { + name: "const_value_change_reuses_preload", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"blue\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "preload", wantSameOutput: true}, + }, + { + name: "var_initializer_change_reuses_preload", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 2", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "preload", wantSameOutput: true}, + }, + { + name: "add_top_level_helper_uses_local_fastpath", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + 
writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func NewTag() string { return \"tag\" }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: true}, + }, + { + name: "import_only_implementation_change_reuses_preload", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return fmt.Sprint(msg) }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "preload", wantSameOutput: true}, + }, + { + name: "signature_change_uses_local_fastpath", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 7", + "", + "type Foo struct {", + "\tMessage string", + "\tCount int", + "}", + "", + "func NewMessage() string { return SQLText }", + "", + "func NewCount() int { return defaultCount }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string, count int) *Foo {", + "\treturn &Foo{Message: helper(msg), Count: count}", + "}", + "", + }, "\n")) + writeFile(t, fx.wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, NewCount, New)", + "", + }, "\n")) + }, + 
want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: false}, + }, + { + name: "struct_field_addition_uses_local_fastpath", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct {", + "\tMessage string", + "\tCount int", + "}", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg), Count: defaultCount}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: true}, + }, + { + name: "interface_method_addition_uses_local_fastpath", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Fooer interface {", + "\tMessage() string", + "\tCount() int", + "}", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: true}, + }, + { + name: "new_source_file_uses_local_fastpath", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.extraFile, strings.Join([]string{ + "package dep", + "", + "func NewTag() string { return \"tag\" }", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "fast", wantSameOutput: true}, + }, + { + name: "invalid_body_change_falls_back_and_errors", + apply: func(t *testing.T, fx incrementalScenarioFixture) { + writeFile(t, fx.depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", 
+ "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return missing }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + }, + want: scenarioExpectation{mode: "generate_load", wantErr: true}, + }, + } + + for _, scenario := range scenarios { + scenario := scenario + t.Run(scenario.name, func(t *testing.T) { + fx := newIncrementalScenarioFixture(t) + + first, errs := Generate(fx.ctx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("baseline Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected baseline Generate result: %+v", first) + } + + scenario.apply(t, fx) + + var labels []string + timedCtx := WithTiming(fx.ctx, func(label string, _ time.Duration) { + labels = append(labels, label) + }) + second, errs := Generate(timedCtx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) + + if scenario.want.wantErr { + if len(errs) == 0 { + t.Fatal("expected Generate to return errors") + } + if len(second) != 0 { + t.Fatalf("expected invalid incremental generate to stop before generation, got %+v", second) + } + } else { + if len(errs) > 0 { + t.Fatalf("incremental Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected incremental Generate result: %+v", second) + } + } + + switch scenario.want.mode { + case "preload": + if containsLabel(labels, "generate.load") { + t.Fatalf("expected preload reuse without generate.load, labels=%v", labels) + } + case "fast": + if containsLabel(labels, "generate.load") { + t.Fatalf("expected fast incremental path without generate.load, labels=%v", labels) + } + case "local_fastpath": + if containsLabel(labels, "generate.load") { + t.Fatalf("expected local fast path without generate.load, labels=%v", labels) + } + if containsLabel(labels, 
"load.packages.lazy.load") { + t.Fatalf("expected local fast path to skip lazy load, labels=%v", labels) + } + if !containsLabel(labels, "incremental.local_fastpath.load") { + t.Fatalf("expected local fast path load, labels=%v", labels) + } + case "generate_load": + if !containsLabel(labels, "generate.load") { + t.Fatalf("expected generate.load fallback, labels=%v", labels) + } + default: + t.Fatalf("unknown expected mode %q", scenario.want.mode) + } + + if scenario.want.wantErr { + return + } + + normal, errs := Generate(context.Background(), fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("normal Generate returned errors after edit: %v", errs) + } + if len(normal) != 1 || len(normal[0].Errs) > 0 { + t.Fatalf("unexpected normal Generate result after edit: %+v", normal) + } + if second[0].OutputPath != normal[0].OutputPath { + t.Fatalf("output paths differ: incremental=%q normal=%q", second[0].OutputPath, normal[0].OutputPath) + } + if string(second[0].Content) != string(normal[0].Content) { + t.Fatalf("incremental output differs from normal output after %s", scenario.name) + } + if scenario.want.wantSameOutput && string(first[0].Content) != string(second[0].Content) { + t.Fatalf("expected generated output to stay unchanged for %s", scenario.name) + } + if !scenario.want.wantSameOutput && string(first[0].Content) == string(second[0].Content) { + t.Fatalf("expected generated output to change for %s", scenario.name) + } + }) + } +} + func TestGenerateIncrementalShapeChangeFallsBackToLazyLoad(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() @@ -933,6 +1614,155 @@ func TestGenerateIncrementalColdBootstrapStillSeedsFastPath(t *testing.T) { } } +func TestLoadLocalPackagesForFastPathImportsUnchangedLocalDependencyFromLocalExport(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + 
repoRoot := mustRepoRoot(t) + root := t.TempDir() + writeDepRouterModule(t, root, repoRoot) + + env := append(os.Environ(), "GOWORK=off") + incrementalCtx := WithIncremental(context.Background(), true) + + if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + t.Fatalf("baseline incremental Generate returned errors: %v", errs) + } + + depPkgPath := "example.com/app/dep" + depExportPath := mustLocalExportPath(t, root, env, depPkgPath) + if _, err := os.Stat(depExportPath); err != nil { + t.Fatalf("expected local export artifact at %s: %v", depExportPath, err) + } + + mutateRouterModule(t, root) + + preloadState, ok := prepareIncrementalPreloadState(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) + if !ok || preloadState == nil || preloadState.manifest == nil { + t.Fatal("expected preload state after baseline incremental generate") + } + loaded, err := loadLocalPackagesForFastPath(context.Background(), root, "", "example.com/app/app", []string{"example.com/app/router"}, preloadState.currentLocal, preloadState.manifest.ExternalPkgs) + if err != nil { + t.Fatalf("loadLocalPackagesForFastPath returned error: %v", err) + } + if _, ok := loaded.loader.localExports[depPkgPath]; !ok { + t.Fatalf("expected %s to be a local export candidate", depPkgPath) + } + if _, ok := loaded.loader.sourcePkgs[depPkgPath]; ok { + t.Fatalf("did not expect %s to be source-loaded", depPkgPath) + } + typesPkg, err := loaded.loader.importPackage(depPkgPath) + if err != nil { + t.Fatalf("importPackage(%s) returned error: %v", depPkgPath, err) + } + if typesPkg == nil || !typesPkg.Complete() { + t.Fatalf("expected complete imported package for %s, got %#v", depPkgPath, typesPkg) + } + if loaded.loader.pkgs[depPkgPath] != nil { + t.Fatalf("expected %s to avoid source loading when local export artifact is present", depPkgPath) + } +} + +func TestGenerateIncrementalMissingLocalExportFallsBackSafely(t *testing.T) { + 
lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + writeDepRouterModule(t, root, repoRoot) + + env := append(os.Environ(), "GOWORK=off") + incrementalCtx := WithIncremental(context.Background(), true) + + if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + t.Fatalf("baseline incremental Generate returned errors: %v", errs) + } + + depExportPath := mustLocalExportPath(t, root, env, "example.com/app/dep") + if err := os.Remove(depExportPath); err != nil { + t.Fatalf("Remove(%s) failed: %v", depExportPath, err) + } + + mutateRouterModule(t, root) + + var labels []string + timedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { + labels = append(labels, label) + }) + gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + t.Fatalf("unexpected Generate results: %+v", gens) + } + if !containsLabel(labels, "incremental.local_fastpath.load") { + t.Fatalf("expected missing local export to stay on local fast path, labels=%v", labels) + } + refreshedExportPath := mustLocalExportPath(t, root, env, "example.com/app/dep") + if _, err := os.Stat(refreshedExportPath); err != nil { + t.Fatalf("expected local export artifact to be refreshed at %s: %v", refreshedExportPath, err) + } +} + +func TestGenerateIncrementalCorruptedLocalExportFallsBackSafely(t *testing.T) { + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + writeDepRouterModule(t, root, repoRoot) + + env := append(os.Environ(), "GOWORK=off") + 
incrementalCtx := WithIncremental(context.Background(), true) + + if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { + t.Fatalf("baseline incremental Generate returned errors: %v", errs) + } + + depExportPath := mustLocalExportPath(t, root, env, "example.com/app/dep") + if err := os.WriteFile(depExportPath, []byte("not-a-valid-export"), 0644); err != nil { + t.Fatalf("WriteFile(%s) failed: %v", depExportPath, err) + } + + mutateRouterModule(t, root) + + var labels []string + timedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { + labels = append(labels, label) + }) + gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("Generate returned errors: %v", errs) + } + if len(gens) != 1 || len(gens[0].Errs) > 0 { + t.Fatalf("unexpected Generate results: %+v", gens) + } + if !containsLabel(labels, "incremental.local_fastpath.load") { + t.Fatalf("expected corrupted local export to stay on local fast path, labels=%v", labels) + } + refreshedExportPath := mustLocalExportPath(t, root, env, "example.com/app/dep") + data, err := os.ReadFile(refreshedExportPath) + if err != nil { + t.Fatalf("ReadFile(%s) failed: %v", refreshedExportPath, err) + } + if string(data) == "not-a-valid-export" { + t.Fatalf("expected corrupted local export artifact to be refreshed at %s", refreshedExportPath) + } +} + func TestGenerateIncrementalShapeChangeWithUnchangedDependentPackageMatchesNormalGenerate(t *testing.T) { lockCacheHooks(t) state := saveCacheHooks() @@ -1174,9 +2004,6 @@ func TestGenerateIncrementalInvalidShapeChangeDoesNotReuseManifest(t *testing.T) if len(errs) == 0 { t.Fatal("expected invalid incremental generate to return errors") } - if containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected invalid incremental generate to stop before slow-path load, labels=%v", secondLabels) - } if got := errs[0].Error(); !strings.Contains(got, 
"type-check failed for example.com/app/app") { t.Fatalf("expected fast-path type-check error, got %q", got) } @@ -1327,8 +2154,8 @@ func TestGenerateIncrementalRecoversAfterInvalidShapeChange(t *testing.T) { if string(third[0].Content) != string(normal[0].Content) { t.Fatal("incremental output differs from normal Generate output after recovering from invalid shape change") } - if !containsLabel(thirdLabels, "generate.load") { - t.Fatalf("expected recovery run to fall back to normal load after invalidating stale manifest, labels=%v", thirdLabels) + if !containsLabel(thirdLabels, "incremental.local_fastpath.load") && !containsLabel(thirdLabels, "generate.load") { + t.Fatalf("expected recovery run to rebuild through local fast path or normal load, labels=%v", thirdLabels) } } @@ -1466,6 +2293,55 @@ func TestGenerateIncrementalToggleBackToKnownShapeHitsArchivedPreloadManifest(t } } +func TestGenerateIncrementalPreloadHitRefreshesMissingContentHashes(t *testing.T) { + fx := newIncrementalScenarioFixture(t) + + first, errs := Generate(fx.ctx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("baseline Generate returned errors: %v", errs) + } + if len(first) != 1 || len(first[0].Errs) > 0 { + t.Fatalf("unexpected baseline Generate result: %+v", first) + } + + selectorKey := incrementalManifestSelectorKey(fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) + manifest, ok := readIncrementalManifest(selectorKey) + if !ok { + t.Fatal("expected incremental manifest after baseline generate") + } + if len(manifest.LocalPackages) == 0 { + t.Fatal("expected local packages in incremental manifest") + } + + stale := *manifest + stale.LocalPackages = append([]packageFingerprint(nil), manifest.LocalPackages...) 
+ for i := range stale.LocalPackages { + stale.LocalPackages[i].ContentHash = "" + stale.LocalPackages[i].Dirs = nil + } + writeIncrementalManifestFile(selectorKey, &stale) + writeIncrementalManifestFile(incrementalManifestStateKey(selectorKey, stale.LocalPackages), &stale) + + second, errs := Generate(fx.ctx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) + if len(errs) > 0 { + t.Fatalf("refresh Generate returned errors: %v", errs) + } + if len(second) != 1 || len(second[0].Errs) > 0 { + t.Fatalf("unexpected refresh Generate result: %+v", second) + } + + preloadState, ok := prepareIncrementalPreloadState(fx.ctx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) + if !ok { + t.Fatal("expected preload state after manifest refresh") + } + if !preloadState.valid { + t.Fatalf("expected refreshed preload state to be valid, reason=%s", preloadState.reason) + } + if len(preloadState.touched) != 0 { + t.Fatalf("expected refreshed preload state to have no touched packages, got %v", preloadState.touched) + } +} + func containsLabel(labels []string, want string) bool { for _, label := range labels { if label == want { @@ -1475,6 +2351,96 @@ func containsLabel(labels []string, want string) bool { return false } +type incrementalScenarioFixture struct { + root string + env []string + ctx context.Context + depFile string + wireFile string + extraFile string +} + +func newIncrementalScenarioFixture(t *testing.T) incrementalScenarioFixture { + t.Helper() + + lockCacheHooks(t) + state := saveCacheHooks() + t.Cleanup(func() { restoreCacheHooks(state) }) + + cacheRoot := t.TempDir() + osTempDir = func() string { return cacheRoot } + + repoRoot := mustRepoRoot(t) + root := t.TempDir() + + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", 
"wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"github.com/goforj/wire\"", + ")", + "", + "func Init() *dep.Foo {", + "\twire.Build(dep.NewSet)", + "\treturn nil", + "}", + "", + }, "\n")) + + depFile := filepath.Join(root, "dep", "dep.go") + writeFile(t, depFile, strings.Join([]string{ + "package dep", + "", + "const SQLText = \"green\"", + "", + "var defaultCount = 1", + "", + "type Foo struct { Message string }", + "", + "func NewMessage() string { return SQLText }", + "", + "func helper(msg string) string { return msg }", + "", + "func New(msg string) *Foo {", + "\treturn &Foo{Message: helper(msg)}", + "}", + "", + }, "\n")) + + wireFile := filepath.Join(root, "dep", "wire.go") + writeFile(t, wireFile, strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var NewSet = wire.NewSet(NewMessage, New)", + "", + }, "\n")) + + return incrementalScenarioFixture{ + root: root, + env: append(os.Environ(), "GOWORK=off"), + ctx: WithIncremental(context.Background(), true), + depFile: depFile, + wireFile: wireFile, + extraFile: filepath.Join(root, "dep", "extra.go"), + } +} + func mustRepoRoot(t *testing.T) string { t.Helper() wd, err := os.Getwd() @@ -1488,6 +2454,137 @@ func mustRepoRoot(t *testing.T) string { return repoRoot } +func writeDepRouterModule(t *testing.T, root string, repoRoot string) { + t.Helper() + writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/goforj/wire v0.0.0", + "replace github.com/goforj/wire => " + repoRoot, + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "//go:build wireinject", + "// +build wireinject", + "", + "package app", + "", + "import (", + "\t\"example.com/app/dep\"", + "\t\"example.com/app/router\"", + "\t\"github.com/goforj/wire\"", + 
")", + "", + "func Init() *router.Routes {", + "\twire.Build(dep.Set, router.Set)", + "\treturn nil", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Controller struct { Message string }", + "", + "func NewMessage() string { return \"ok\" }", + "", + "func NewController(msg string) *Controller {", + "\treturn &Controller{Message: msg}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ + "package dep", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(NewMessage, NewController)", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "router", "router.go"), strings.Join([]string{ + "package router", + "", + "import \"example.com/app/dep\"", + "", + "type Routes struct { Controller *dep.Controller }", + "", + "func ProvideRoutes(controller *dep.Controller) *Routes {", + "\treturn &Routes{Controller: controller}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "router", "wire.go"), strings.Join([]string{ + "package router", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(ProvideRoutes)", + "", + }, "\n")) +} + +func mutateRouterModule(t *testing.T, root string) { + t.Helper() + writeFile(t, filepath.Join(root, "router", "router.go"), strings.Join([]string{ + "package router", + "", + "import \"example.com/app/dep\"", + "", + "type Routes struct {", + "\tController *dep.Controller", + "\tVersion int", + "}", + "", + "func NewVersion() int {", + "\treturn 2", + "}", + "", + "func ProvideRoutes(controller *dep.Controller, version int) *Routes {", + "\treturn &Routes{Controller: controller, Version: version}", + "}", + "", + }, "\n")) + + writeFile(t, filepath.Join(root, "router", "wire.go"), strings.Join([]string{ + "package router", + "", + "import \"github.com/goforj/wire\"", + "", + "var Set = wire.NewSet(NewVersion, ProvideRoutes)", + "", + }, "\n")) +} + +func 
mustLocalExportPath(t *testing.T, root string, env []string, pkgPath string) string { + t.Helper() + pkgs, loader, errs := load(context.Background(), root, env, "", []string{"./app"}) + if len(errs) > 0 { + t.Fatalf("load returned errors: %v", errs) + } + if loader == nil { + t.Fatal("load returned nil loader") + } + if _, errs := loader.load("example.com/app/app"); len(errs) > 0 { + t.Fatalf("lazy load returned errors: %v", errs) + } + snapshot := buildIncrementalManifestSnapshotFromPackages(root, "", incrementalManifestPackages(pkgs, loader)) + if snapshot == nil || snapshot.fingerprints[pkgPath] == nil { + t.Fatalf("missing fingerprint for %s", pkgPath) + } + path := localExportPathForFingerprint(root, "", snapshot.fingerprints[pkgPath]) + if path == "" { + t.Fatalf("missing local export path for %s", pkgPath) + } + return path +} + func writeFile(t *testing.T, path string, content string) { t.Helper() if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { diff --git a/internal/wire/local_export.go b/internal/wire/local_export.go new file mode 100644 index 0000000..f83ed7b --- /dev/null +++ b/internal/wire/local_export.go @@ -0,0 +1,97 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package wire + +import ( + "crypto/sha256" + "fmt" + "go/token" + "go/types" + "path/filepath" + + "golang.org/x/tools/go/gcexportdata" + "golang.org/x/tools/go/packages" +) + +const localExportVersion = "wire-local-export-v1" + +func localExportKey(wd string, tags string, pkgPath string, shapeHash string) string { + sum := sha256.Sum256([]byte(localExportVersion + "\x00" + packageCacheScope(wd) + "\x00" + tags + "\x00" + pkgPath + "\x00" + shapeHash)) + return fmt.Sprintf("%x", sum[:]) +} + +func localExportPath(key string) string { + return filepath.Join(cacheDir(), key+".iexp") +} + +func localExportPathForFingerprint(wd string, tags string, fp *packageFingerprint) string { + if fp == nil || fp.PkgPath == "" || fp.ShapeHash == "" { + return "" + } + return localExportPath(localExportKey(wd, tags, fp.PkgPath, fp.ShapeHash)) +} + +func localExportExists(wd string, tags string, fp *packageFingerprint) bool { + path := localExportPathForFingerprint(wd, tags, fp) + if path == "" { + return false + } + _, err := osStat(path) + return err == nil +} + +func writeLocalPackageExports(wd string, tags string, pkgs []*packages.Package, fps map[string]*packageFingerprint) { + if len(pkgs) == 0 || len(fps) == 0 { + return + } + moduleRoot := findModuleRoot(wd) + for _, pkg := range pkgs { + if pkg == nil || pkg.Types == nil || pkg.PkgPath == "" { + continue + } + if classifyPackageLocation(moduleRoot, pkg) != "local" { + continue + } + fp := fps[pkg.PkgPath] + path := localExportPathForFingerprint(wd, tags, fp) + if path == "" { + continue + } + writeLocalPackageExportFile(path, pkg.Fset, pkg.Types) + } +} + +func writeLocalPackageExportFile(path string, fset *token.FileSet, pkg *types.Package) { + if path == "" || fset == nil || pkg == nil { + return + } + dir := cacheDir() + if err := osMkdirAll(dir, 0755); err != nil { + return + } + tmp, err := osCreateTemp(dir, filepath.Base(path)+".tmp-") + if err != nil { + return + } + writeErr := gcexportdata.Write(tmp, fset, pkg) 
+ closeErr := tmp.Close() + if writeErr != nil || closeErr != nil { + osRemove(tmp.Name()) + return + } + if err := osRename(tmp.Name(), path); err != nil { + osRemove(tmp.Name()) + } +} diff --git a/internal/wire/local_fastpath.go b/internal/wire/local_fastpath.go index 04d9cb8..89ea402 100644 --- a/internal/wire/local_fastpath.go +++ b/internal/wire/local_fastpath.go @@ -31,6 +31,7 @@ import ( "strings" "time" + "golang.org/x/tools/go/gcexportdata" "golang.org/x/tools/go/packages" ) @@ -96,8 +97,10 @@ func tryIncrementalLocalFastPath(ctx context.Context, wd string, env []string, p for _, path := range snapshot.changed { changedSet[path] = struct{}{} } + currentPackages := loaded.currentPackages() writeIncrementalFingerprints(snapshot, wd, opts.Tags) - writeIncrementalPackageSummariesWithSummary(loader, loaded.allPackages, newSummaryProviderResolver(ctx, loaded.loader.summaries, loaded.loader.importExportPackage), changedSet) + writeLocalPackageExports(wd, opts.Tags, currentPackages, loaded.fingerprints) + writeIncrementalPackageSummariesWithSummary(loader, currentPackages, newSummaryProviderResolver(ctx, loaded.loader.summaries, loaded.loader.importExportPackage), changedSet) writeIncrementalManifestFromState(wd, env, patterns, opts, state, snapshot, generated) writeIncrementalGraphFromSnapshot(wd, opts.Tags, roots, loaded.fingerprints) @@ -235,6 +238,21 @@ type localFastPathLoaded struct { loader *localFastPathLoader } +func (l *localFastPathLoaded) currentPackages() []*packages.Package { + if l == nil { + return nil + } + if l.loader == nil || len(l.loader.pkgs) == 0 { + return l.allPackages + } + all := make([]*packages.Package, 0, len(l.loader.pkgs)) + for _, pkg := range l.loader.pkgs { + all = append(all, pkg) + } + sort.Slice(all, func(i, j int) bool { return all[i].PkgPath < all[j].PkgPath }) + return all +} + type localFastPathLoader struct { ctx context.Context wd string @@ -249,10 +267,25 @@ type localFastPathLoader struct { pkgs 
map[string]*packages.Package imported map[string]*types.Package externalMeta map[string]externalPackageExport + localExports map[string]string externalImp types.Importer + externalFallback types.Importer } func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, rootPkgPath string, changed []string, current []packageFingerprint, external []externalPackageExport) (*localFastPathLoaded, error) { + return loadLocalPackagesForFastPathMode(ctx, wd, tags, rootPkgPath, changed, current, external, false) +} + +func validateTouchedPackagesFastPath(ctx context.Context, wd string, tags string, touched []string, current []packageFingerprint, external []externalPackageExport) error { + if len(touched) == 0 { + return nil + } + rootPkgPath := touched[0] + _, err := loadLocalPackagesForFastPathMode(ctx, wd, tags, rootPkgPath, touched, current, external, true) + return err +} + +func loadLocalPackagesForFastPathMode(ctx context.Context, wd string, tags string, rootPkgPath string, changed []string, current []packageFingerprint, external []externalPackageExport, validationOnly bool) (*localFastPathLoaded, error) { meta := fingerprintsFromSlice(current) if len(meta) == 0 { return nil, fmt.Errorf("no local fingerprints") @@ -265,6 +298,9 @@ func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, r if item.PkgPath == "" || item.ExportFile == "" { continue } + if meta[item.PkgPath] != nil { + continue + } externalMeta[item.PkgPath] = item } loader := &localFastPathLoader{ @@ -281,12 +317,18 @@ func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, r pkgs: make(map[string]*packages.Package, len(meta)), imported: make(map[string]*types.Package, len(meta)+len(externalMeta)), externalMeta: externalMeta, + localExports: make(map[string]string), } for _, path := range changed { loader.changedPkgs[path] = struct{}{} } - loader.markSourceClosure() - candidates := make(map[string]*packageSummary) + if validationOnly { + for 
path := range loader.changedPkgs { + loader.sourcePkgs[path] = struct{}{} + } + } else { + loader.markSourceClosure() + } for path, fp := range meta { if path == rootPkgPath { continue @@ -294,7 +336,19 @@ func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, r if _, changed := loader.changedPkgs[path]; changed { continue } - if _, ok := externalMeta[path]; !ok { + if _, ok := loader.sourcePkgs[path]; ok { + continue + } + if exportPath := localExportPathForFingerprint(wd, tags, fp); exportPath != "" && localExportExists(wd, tags, fp) { + loader.localExports[path] = exportPath + } + } + candidates := make(map[string]*packageSummary) + for path, fp := range meta { + if path == rootPkgPath { + continue + } + if _, changed := loader.changedPkgs[path]; changed { continue } summary, ok := readIncrementalPackageSummary(incrementalSummaryKey(wd, tags, path)) @@ -305,9 +359,24 @@ func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, r } loader.summaries = filterSupportedPackageSummaries(candidates) loader.externalImp = importerpkg.ForCompiler(loader.fset, "gc", loader.openExternalExport) - root, err := loader.load(rootPkgPath) - if err != nil { - return nil, err + loader.externalFallback = importerpkg.ForCompiler(loader.fset, "gc", nil) + var root *packages.Package + if validationOnly { + for _, path := range changed { + pkg, err := loader.load(path) + if err != nil { + return nil, err + } + if root == nil { + root = pkg + } + } + } else { + var err error + root, err = loader.load(rootPkgPath) + if err != nil { + return nil, err + } } all := make([]*packages.Package, 0, len(loader.pkgs)) for _, pkg := range loader.pkgs { @@ -341,6 +410,7 @@ func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { mode |= parser.ParseComments } syntax := make([]*ast.File, 0, len(files)) + parseStart := time.Now() for _, name := range files { file, err := l.parseFileForFastPath(name, mode, pkgPath) if err != nil { @@ 
-348,6 +418,7 @@ func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { } syntax = append(syntax, file) } + logTiming(l.ctx, "incremental.local_fastpath.parse", parseStart) if len(syntax) == 0 { return nil, fmt.Errorf("package %s parsed no files", pkgPath) } @@ -376,7 +447,9 @@ func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { pkg.Errors = append(pkg.Errors, packages.Error{Msg: err.Error()}) }, } + typecheckStart := time.Now() checkedPkg, err := conf.Check(pkgPath, l.fset, syntax, info) + logTiming(l.ctx, "incremental.local_fastpath.typecheck", typecheckStart) if checkedPkg != nil { pkg.Types = checkedPkg l.imported[pkgPath] = checkedPkg @@ -422,6 +495,7 @@ func (l *localFastPathLoader) parseFileForFastPath(name string, mode parser.Mode func (l *localFastPathLoader) reloadWithoutBodyStripping(pkgPath string, files []string, mode parser.Mode, pkg *packages.Package) (*packages.Package, error) { syntax := make([]*ast.File, 0, len(files)) + parseStart := time.Now() for _, name := range files { file, err := parser.ParseFile(l.fset, name, nil, mode) if err != nil { @@ -429,6 +503,7 @@ func (l *localFastPathLoader) reloadWithoutBodyStripping(pkgPath string, files [ } syntax = append(syntax, file) } + logTiming(l.ctx, "incremental.local_fastpath.parse_retry", parseStart) pkg.Syntax = syntax pkg.Errors = nil pkg.TypesInfo = newFastPathTypesInfo(pkgPath == l.rootPkgPath) @@ -442,7 +517,9 @@ func (l *localFastPathLoader) reloadWithoutBodyStripping(pkgPath string, files [ pkg.Errors = append(pkg.Errors, packages.Error{Msg: err.Error()}) }, } + typecheckStart := time.Now() checkedPkg, err := conf.Check(pkgPath, l.fset, syntax, pkg.TypesInfo) + logTiming(l.ctx, "incremental.local_fastpath.typecheck_retry", typecheckStart) if checkedPkg != nil { pkg.Types = checkedPkg l.imported[pkgPath] = checkedPkg @@ -471,13 +548,29 @@ func (l *localFastPathLoader) shouldRetryWithoutBodyStripping(pkgPath string, er func (l 
*localFastPathLoader) importPackage(path string) (*types.Package, error) { if l.shouldImportFromExport(path) { - return l.importExportPackage(path) + pkg, err := l.importExportPackage(path) + if err == nil { + return pkg, nil + } + // Cached local export artifacts are an optimization only. If one is + // missing or corrupted, fall back to source loading for correctness. + if _, ok := l.localExports[path]; ok && l.meta[path] != nil { + delete(l.localExports, path) + pkg, loadErr := l.load(path) + if loadErr == nil { + l.refreshLocalExport(path, pkg) + return pkg.Types, nil + } + return nil, loadErr + } + return nil, err } if l.meta[path] != nil { pkg, err := l.load(path) if err != nil { return nil, err } + l.refreshLocalExport(path, pkg) return pkg.Types, nil } if l.externalImp == nil { @@ -536,7 +629,20 @@ func (l *localFastPathLoader) importExportPackage(path string) (*types.Package, if l == nil { return nil, fmt.Errorf("missing local fast path loader") } - if pkg := l.imported[path]; pkg != nil { + if pkg := l.imported[path]; pkg != nil && pkg.Complete() { + return pkg, nil + } + if exportPath := l.localExports[path]; exportPath != "" { + f, err := os.Open(exportPath) + if err != nil { + return nil, err + } + defer f.Close() + pkg, err := gcexportdata.Read(f, l.fset, l.imported, path) + if err != nil { + return nil, err + } + l.imported[path] = pkg return pkg, nil } if l.externalImp == nil { @@ -544,6 +650,13 @@ func (l *localFastPathLoader) importExportPackage(path string) (*types.Package, } pkg, err := l.externalImp.Import(path) if err != nil { + if l.externalFallback != nil && strings.Contains(err.Error(), "missing external export data for ") { + pkg, fallbackErr := l.externalFallback.Import(path) + if fallbackErr == nil { + l.imported[path] = pkg + return pkg, nil + } + } return nil, err } l.imported[path] = pkg @@ -557,10 +670,26 @@ func (l *localFastPathLoader) shouldImportFromExport(pkgPath string) bool { if _, source := l.sourcePkgs[pkgPath]; source { 
return false } - _, ok := l.summaries[pkgPath] + if _, ok := l.localExports[pkgPath]; ok { + return true + } + _, ok := l.externalMeta[pkgPath] return ok } +func (l *localFastPathLoader) refreshLocalExport(pkgPath string, pkg *packages.Package) { + if l == nil || pkg == nil || pkg.Fset == nil || pkg.Types == nil { + return + } + fp := l.meta[pkgPath] + exportPath := localExportPathForFingerprint(l.wd, l.tags, fp) + if exportPath == "" { + return + } + writeLocalPackageExportFile(exportPath, pkg.Fset, pkg.Types) + l.localExports[pkgPath] = exportPath +} + func (l *localFastPathLoader) markSourceClosure() { if l == nil { return @@ -802,14 +931,15 @@ func writeIncrementalManifestFromState(wd string, env []string, patterns []strin if snapshot == nil || len(generated) == 0 || state == nil || state.manifest == nil { return } + scope := runCacheScope(wd, patterns) manifest := &incrementalManifest{ Version: incrementalManifestVersion, - WD: filepath.Clean(wd), + WD: scope, Tags: opts.Tags, Prefix: opts.PrefixOutputFile, HeaderHash: headerHash(opts.Header), EnvHash: envHash(env), - Patterns: sortedStrings(patterns), + Patterns: normalizePatternsForScope(wd, packageCacheScope(wd), patterns), LocalPackages: snapshotPackageFingerprints(snapshot), ExternalPkgs: append([]externalPackageExport(nil), state.manifest.ExternalPkgs...), ExternalFiles: append([]cacheFile(nil), state.manifest.ExternalFiles...), @@ -841,7 +971,7 @@ func writeIncrementalGraphFromSnapshot(wd string, tags string, roots []string, f } graph := &incrementalGraph{ Version: incrementalGraphVersion, - WD: filepath.Clean(wd), + WD: packageCacheScope(wd), Tags: tags, Roots: append([]string(nil), roots...), LocalReverse: make(map[string][]string), diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 1b6140f..24ca575 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -154,9 +154,13 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o snapshot := 
buildIncrementalManifestSnapshotFromPackages(wd, opts.Tags, incrementalManifestPackages(pkgs, loader)) writeIncrementalManifestWithOptions(wd, env, patterns, opts, incrementalManifestPackages(pkgs, loader), snapshot, generated, false) if snapshot != nil { + writeLocalPackageExports(wd, opts.Tags, incrementalManifestPackages(pkgs, loader), snapshot.fingerprints) writeIncrementalGraphFromSnapshot(wd, opts.Tags, manifestOutputPkgPathsFromGenerated(generated), snapshot.fingerprints) + loader.fingerprints = snapshot } + writeIncrementalPackageSummaries(loader, pkgs) } else { + writeLocalPackageExports(wd, opts.Tags, incrementalManifestPackages(pkgs, loader), loader.fingerprints.fingerprints) writeIncrementalPackageSummaries(loader, pkgs) writeIncrementalManifest(wd, env, patterns, opts, incrementalManifestPackages(pkgs, loader), loader.fingerprints, generated) } diff --git a/scripts/incremental-scenarios.sh b/scripts/incremental-scenarios.sh new file mode 100755 index 0000000..b59e970 --- /dev/null +++ b/scripts/incremental-scenarios.sh @@ -0,0 +1,137 @@ +#!/usr/bin/env bash + +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +export GOCACHE="${GOCACHE:-/tmp/gocache}" +export GOMODCACHE="${GOMODCACHE:-/tmp/gomodcache}" + +usage() { + cat <<'EOF' +Usage: + scripts/incremental-scenarios.sh test + scripts/incremental-scenarios.sh matrix + scripts/incremental-scenarios.sh table + scripts/incremental-scenarios.sh budgets + scripts/incremental-scenarios.sh bench + scripts/incremental-scenarios.sh large-table + scripts/incremental-scenarios.sh large-breakdown + scripts/incremental-scenarios.sh report + scripts/incremental-scenarios.sh all + +Commands: + test Run the full internal/wire test suite. + matrix Run the incremental scenario matrix correctness test. + table Print the incremental scenario timing table. + budgets Enforce the incremental scenario performance budgets. + bench Run the incremental scenario benchmark suite. 
+ large-table Print the large-repo comparison timing table. + large-breakdown Print the large-repo shape-change breakdown table. + report Run the main timing report: scenario table, budgets, and large-repo table. + all Run matrix, table, budgets, and the large-repo table in sequence. +EOF +} + +print_section() { + local title="$1" + printf '\n== %s ==\n' "$title" +} + +print_test_table() { + local output_file="$1" + awk ' + /^\+[-+]+\+$/ { in_table=1 } + in_table && !/^--- PASS:/ && !/^PASS$/ && !/^ok[[:space:]]/ { print } + /^--- PASS:/ && in_table { exit } + ' "$output_file" +} + +run_test_table() { + local env_var="$1" + local test_name="$2" + local output_file + output_file="$(mktemp)" + env "$env_var"=1 go test ./internal/wire -run "$test_name" -count=1 -v >"$output_file" + print_test_table "$output_file" + rm -f "$output_file" +} + +run_test() { + go test ./internal/wire -count=1 +} + +run_matrix() { + go test ./internal/wire -run TestGenerateIncrementalScenarioMatrix -count=1 +} + +run_table() { + run_test_table WIRE_BENCH_SCENARIOS TestPrintIncrementalScenarioBenchmarkTable +} + +run_budgets() { + WIRE_PERF_BUDGETS=1 go test ./internal/wire -run TestIncrementalScenarioPerformanceBudgets -count=1 >/dev/null + echo "PASS" +} + +run_bench() { + go test ./internal/wire -run '^$' -bench BenchmarkGenerateIncrementalScenarioMatrix -benchmem -count=1 +} + +run_large_table() { + run_test_table WIRE_BENCH_TABLE TestPrintLargeRepoBenchmarkComparisonTable +} + +run_large_breakdown() { + run_test_table WIRE_BENCH_BREAKDOWN TestPrintLargeRepoShapeChangeBreakdownTable +} + +run_report() { + print_section "Scenario Timing Table" + run_table + print_section "Scenario Performance Budgets" + run_budgets + print_section "Large Repo Comparison Table" + run_large_table +} + +cmd="${1:-}" +case "$cmd" in + test) + run_test + ;; + matrix) + run_matrix + ;; + table) + run_table + ;; + budgets) + run_budgets + ;; + bench) + run_bench + ;; + large-table) + run_large_table + ;; + 
large-breakdown) + run_large_breakdown + ;; + report) + run_report + ;; + all) + run_matrix + run_report + ;; + ""|-h|--help|help) + usage + ;; + *) + echo "Unknown command: $cmd" >&2 + usage >&2 + exit 1 + ;; +esac From d3486f551be2d63e51dc9bfc04440e91bb942bdc Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 03:15:56 -0500 Subject: [PATCH 08/82] feat: custom loader initial --- cmd/wire/cache_cmd.go | 67 - cmd/wire/check_cmd.go | 3 - cmd/wire/diff_cmd.go | 3 - cmd/wire/gen_cmd.go | 3 - cmd/wire/incremental_flag.go | 60 - cmd/wire/main.go | 9 +- cmd/wire/show_cmd.go | 3 - cmd/wire/watch_cmd.go | 3 - internal/loader/custom.go | 1014 +++++++ internal/loader/discovery.go | 94 + internal/loader/fallback.go | 224 ++ internal/loader/loader.go | 135 + internal/loader/loader_test.go | 821 ++++++ internal/loader/mode.go | 38 + internal/loader/timing.go | 41 + internal/wire/cache_bypass.go | 17 - internal/wire/cache_coverage_test.go | 1099 ------- internal/wire/cache_generate_test.go | 100 - internal/wire/cache_key.go | 352 --- internal/wire/cache_manifest.go | 393 --- internal/wire/cache_scope.go | 69 - internal/wire/cache_scope_test.go | 59 - internal/wire/cache_store.go | 77 - internal/wire/cache_test.go | 387 --- internal/wire/generate_package.go | 126 - internal/wire/generate_package_test.go | 137 - internal/wire/incremental.go | 85 - internal/wire/incremental_bench_test.go | 1495 ---------- internal/wire/incremental_fingerprint.go | 674 ----- internal/wire/incremental_fingerprint_test.go | 142 - internal/wire/incremental_graph.go | 306 -- internal/wire/incremental_graph_test.go | 97 - internal/wire/incremental_manifest.go | 1158 -------- internal/wire/incremental_session.go | 102 - internal/wire/incremental_summary.go | 656 ----- internal/wire/incremental_summary_test.go | 295 -- internal/wire/incremental_test.go | 65 - internal/wire/load_debug.go | 31 +- internal/wire/loader_test.go | 2596 ----------------- internal/wire/loader_timing_bridge.go | 17 + 
.../{cache_hooks.go => loader_validation.go} | 34 +- internal/wire/local_export.go | 97 - internal/wire/local_fastpath.go | 1031 ------- internal/wire/parse.go | 146 +- internal/wire/parse_coverage_test.go | 12 +- internal/wire/parser_lazy_loader.go | 188 -- internal/wire/parser_lazy_loader_test.go | 204 -- internal/wire/summary_provider_resolver.go | 223 -- internal/wire/wire.go | 149 +- internal/wire/wire_test.go | 121 +- 50 files changed, 2645 insertions(+), 12613 deletions(-) delete mode 100644 cmd/wire/cache_cmd.go delete mode 100644 cmd/wire/incremental_flag.go create mode 100644 internal/loader/custom.go create mode 100644 internal/loader/discovery.go create mode 100644 internal/loader/fallback.go create mode 100644 internal/loader/loader.go create mode 100644 internal/loader/loader_test.go create mode 100644 internal/loader/mode.go create mode 100644 internal/loader/timing.go delete mode 100644 internal/wire/cache_bypass.go delete mode 100644 internal/wire/cache_coverage_test.go delete mode 100644 internal/wire/cache_generate_test.go delete mode 100644 internal/wire/cache_key.go delete mode 100644 internal/wire/cache_manifest.go delete mode 100644 internal/wire/cache_scope.go delete mode 100644 internal/wire/cache_scope_test.go delete mode 100644 internal/wire/cache_store.go delete mode 100644 internal/wire/cache_test.go delete mode 100644 internal/wire/generate_package.go delete mode 100644 internal/wire/generate_package_test.go delete mode 100644 internal/wire/incremental.go delete mode 100644 internal/wire/incremental_bench_test.go delete mode 100644 internal/wire/incremental_fingerprint.go delete mode 100644 internal/wire/incremental_fingerprint_test.go delete mode 100644 internal/wire/incremental_graph.go delete mode 100644 internal/wire/incremental_graph_test.go delete mode 100644 internal/wire/incremental_manifest.go delete mode 100644 internal/wire/incremental_session.go delete mode 100644 internal/wire/incremental_summary.go delete mode 100644 
internal/wire/incremental_summary_test.go delete mode 100644 internal/wire/incremental_test.go delete mode 100644 internal/wire/loader_test.go create mode 100644 internal/wire/loader_timing_bridge.go rename internal/wire/{cache_hooks.go => loader_validation.go} (50%) delete mode 100644 internal/wire/local_export.go delete mode 100644 internal/wire/local_fastpath.go delete mode 100644 internal/wire/parser_lazy_loader.go delete mode 100644 internal/wire/parser_lazy_loader_test.go delete mode 100644 internal/wire/summary_provider_resolver.go diff --git a/cmd/wire/cache_cmd.go b/cmd/wire/cache_cmd.go deleted file mode 100644 index f34d381..0000000 --- a/cmd/wire/cache_cmd.go +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2018 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package main - -import ( - "context" - "flag" - "fmt" - "log" - - "github.com/goforj/wire/internal/wire" - "github.com/google/subcommands" -) - -type cacheCmd struct { - clear bool -} - -// Name returns the subcommand name. -func (*cacheCmd) Name() string { return "cache" } - -// Synopsis returns a short summary of the subcommand. -func (*cacheCmd) Synopsis() string { - return "inspect or clear the wire cache" -} - -// Usage returns the help text for the subcommand. -func (*cacheCmd) Usage() string { - return `cache [-clear|clear] - - By default, prints the cache directory. With -clear or clear, removes all cache files. 
-` -} - -// SetFlags registers flags for the subcommand. -func (cmd *cacheCmd) SetFlags(f *flag.FlagSet) { - f.BoolVar(&cmd.clear, "clear", false, "remove all cached data") -} - -// Execute runs the subcommand. -func (cmd *cacheCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { - if f.NArg() > 0 && f.Arg(0) == "clear" { - cmd.clear = true - } - if cmd.clear { - if err := wire.ClearCache(); err != nil { - log.Printf("failed to clear cache: %v\n", err) - return subcommands.ExitFailure - } - log.Printf("cleared cache at %s\n", wire.CacheDir()) - return subcommands.ExitSuccess - } - fmt.Println(wire.CacheDir()) - return subcommands.ExitSuccess -} diff --git a/cmd/wire/check_cmd.go b/cmd/wire/check_cmd.go index 71872d9..897bec2 100644 --- a/cmd/wire/check_cmd.go +++ b/cmd/wire/check_cmd.go @@ -27,7 +27,6 @@ import ( type checkCmd struct { tags string - incremental optionalBoolFlag profile profileFlags } @@ -53,7 +52,6 @@ func (*checkCmd) Usage() string { // SetFlags registers flags for the subcommand. 
func (cmd *checkCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") - addIncrementalFlag(&cmd.incremental, f) cmd.profile.addFlags(f) } @@ -67,7 +65,6 @@ func (cmd *checkCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter defer stop() totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) - ctx = cmd.incremental.apply(ctx) wd, err := os.Getwd() if err != nil { diff --git a/cmd/wire/diff_cmd.go b/cmd/wire/diff_cmd.go index c7facca..5aad2f1 100644 --- a/cmd/wire/diff_cmd.go +++ b/cmd/wire/diff_cmd.go @@ -31,7 +31,6 @@ import ( type diffCmd struct { headerFile string tags string - incremental optionalBoolFlag profile profileFlags } @@ -61,7 +60,6 @@ func (*diffCmd) Usage() string { func (cmd *diffCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") - addIncrementalFlag(&cmd.incremental, f) cmd.profile.addFlags(f) } @@ -79,7 +77,6 @@ func (cmd *diffCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interf defer stop() totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) - ctx = cmd.incremental.apply(ctx) wd, err := os.Getwd() if err != nil { diff --git a/cmd/wire/gen_cmd.go b/cmd/wire/gen_cmd.go index e98556f..aceefee 100644 --- a/cmd/wire/gen_cmd.go +++ b/cmd/wire/gen_cmd.go @@ -29,7 +29,6 @@ type genCmd struct { headerFile string prefixFileName string tags string - incremental optionalBoolFlag profile profileFlags } @@ -56,7 +55,6 @@ func (cmd *genCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") f.StringVar(&cmd.prefixFileName, "output_file_prefix", "", "string to prepend to output file names.") f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") - addIncrementalFlag(&cmd.incremental, f) 
cmd.profile.addFlags(f) } @@ -70,7 +68,6 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa defer stop() totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) - ctx = cmd.incremental.apply(ctx) wd, err := os.Getwd() if err != nil { diff --git a/cmd/wire/incremental_flag.go b/cmd/wire/incremental_flag.go deleted file mode 100644 index 2962128..0000000 --- a/cmd/wire/incremental_flag.go +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package main - -import ( - "context" - "flag" - "strconv" - - "github.com/goforj/wire/internal/wire" -) - -type optionalBoolFlag struct { - value bool - set bool -} - -func (f *optionalBoolFlag) String() string { - if f == nil { - return "" - } - return strconv.FormatBool(f.value) -} - -func (f *optionalBoolFlag) Set(s string) error { - v, err := strconv.ParseBool(s) - if err != nil { - return err - } - f.value = v - f.set = true - return nil -} - -func (f *optionalBoolFlag) IsBoolFlag() bool { - return true -} - -func (f *optionalBoolFlag) apply(ctx context.Context) context.Context { - if f == nil || !f.set { - return ctx - } - return wire.WithIncremental(ctx, f.value) -} - -func addIncrementalFlag(f *optionalBoolFlag, fs *flag.FlagSet) { - fs.Var(f, "incremental", "enable the incremental engine (overrides "+wire.IncrementalEnvVar+")") -} diff --git a/cmd/wire/main.go b/cmd/wire/main.go index efaf767..4426ee1 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -35,8 +35,6 @@ import ( "github.com/google/subcommands" ) -var topLevelIncremental optionalBoolFlag - const ( ansiRed = "\033[1;31m" ansiGreen = "\033[1;32m" @@ -52,12 +50,10 @@ func main() { subcommands.Register(subcommands.FlagsCommand(), "") subcommands.Register(subcommands.HelpCommand(), "") subcommands.Register(&checkCmd{}, "") - subcommands.Register(&cacheCmd{}, "") subcommands.Register(&diffCmd{}, "") subcommands.Register(&genCmd{}, "") subcommands.Register(&watchCmd{}, "") subcommands.Register(&showCmd{}, "") - addIncrementalFlag(&topLevelIncremental, flag.CommandLine) flag.Parse() // Initialize the default logger to log to stderr. @@ -74,7 +70,6 @@ func main() { "help": true, // builtin "flags": true, // builtin "check": true, - "cache": true, "diff": true, "gen": true, "serve": true, @@ -84,9 +79,9 @@ func main() { // Default to running the "gen" command. 
if args := flag.Args(); len(args) == 0 || !allCmds[args[0]] { genCmd := &genCmd{} - os.Exit(int(genCmd.Execute(topLevelIncremental.apply(context.Background()), flag.CommandLine))) + os.Exit(int(genCmd.Execute(context.Background(), flag.CommandLine))) } - os.Exit(int(subcommands.Execute(topLevelIncremental.apply(context.Background())))) + os.Exit(int(subcommands.Execute(context.Background()))) } // installStackDumper registers signal handlers to dump goroutine stacks. diff --git a/cmd/wire/show_cmd.go b/cmd/wire/show_cmd.go index 1313ade..10c737f 100644 --- a/cmd/wire/show_cmd.go +++ b/cmd/wire/show_cmd.go @@ -35,7 +35,6 @@ import ( type showCmd struct { tags string - incremental optionalBoolFlag profile profileFlags } @@ -63,7 +62,6 @@ func (*showCmd) Usage() string { // SetFlags registers flags for the subcommand. func (cmd *showCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") - addIncrementalFlag(&cmd.incremental, f) cmd.profile.addFlags(f) } @@ -77,7 +75,6 @@ func (cmd *showCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interf defer stop() totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) - ctx = cmd.incremental.apply(ctx) wd, err := os.Getwd() if err != nil { diff --git a/cmd/wire/watch_cmd.go b/cmd/wire/watch_cmd.go index cb1b31b..ebdfa0e 100644 --- a/cmd/wire/watch_cmd.go +++ b/cmd/wire/watch_cmd.go @@ -36,7 +36,6 @@ type watchCmd struct { headerFile string prefixFileName string tags string - incremental optionalBoolFlag profile profileFlags pollInterval time.Duration rescanInterval time.Duration @@ -64,7 +63,6 @@ func (cmd *watchCmd) SetFlags(f *flag.FlagSet) { f.StringVar(&cmd.headerFile, "header_file", "", "path to file to insert as a header in wire_gen.go") f.StringVar(&cmd.prefixFileName, "output_file_prefix", "", "string to prepend to output file names.") f.StringVar(&cmd.tags, "tags", "", "append build tags to the default wirebuild") - 
addIncrementalFlag(&cmd.incremental, f) f.DurationVar(&cmd.pollInterval, "poll_interval", 250*time.Millisecond, "interval between file stat checks") f.DurationVar(&cmd.rescanInterval, "rescan_interval", 2*time.Second, "interval to rescan for new or removed Go files") cmd.profile.addFlags(f) @@ -79,7 +77,6 @@ func (cmd *watchCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter } defer stop() ctx = withTiming(ctx, cmd.profile.timings) - ctx = cmd.incremental.apply(ctx) if cmd.pollInterval <= 0 { log.Println("poll_interval must be greater than zero") diff --git a/internal/loader/custom.go b/internal/loader/custom.go new file mode 100644 index 0000000..ffa2d48 --- /dev/null +++ b/internal/loader/custom.go @@ -0,0 +1,1014 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package loader + +import ( + "context" + "fmt" + "go/ast" + importerpkg "go/importer" + "go/parser" + "go/scanner" + "go/token" + "go/types" + "os" + "path/filepath" + "runtime" + "runtime/pprof" + "sort" + "strings" + "time" + + "golang.org/x/tools/go/gcexportdata" + "golang.org/x/tools/go/packages" +) + +type unsupportedError struct { + reason string +} + +func (e unsupportedError) Error() string { return e.reason } + +type packageMeta struct { + ImportPath string + Name string + Dir string + DepOnly bool + Export string + GoFiles []string + CompiledGoFiles []string + Imports []string + ImportMap map[string]string + Error *goListError +} + +type goListError struct { + Err string +} + +type customValidator struct { + fset *token.FileSet + meta map[string]*packageMeta + touched map[string]struct{} + packages map[string]*types.Package + importer types.Importer + loading map[string]bool +} + +type customTypedGraphLoader struct { + workspace string + ctx context.Context + fset *token.FileSet + meta map[string]*packageMeta + targets map[string]struct{} + parseFile ParseFileFunc + packages map[string]*packages.Package + typesPkgs map[string]*types.Package + importer types.Importer + loading map[string]bool + stats typedLoadStats +} + +type typedLoadStats struct { + read time.Duration + parse time.Duration + typecheck time.Duration + localRead time.Duration + externalRead time.Duration + localParse time.Duration + externalParse time.Duration + localTypecheck time.Duration + externalTypecheck time.Duration + filesRead int + packages int + localPackages int + externalPackages int + localFilesRead int + externalFilesRead int +} + +func validateTouchedPackagesCustom(ctx context.Context, req TouchedValidationRequest) (*TouchedValidationResult, error) { + if len(req.Touched) == 0 { + return &TouchedValidationResult{Backend: ModeCustom}, nil + } + meta, err := discoverTouchedMetadata(ctx, req) + if err != nil { + return nil, err + } + validator := &customValidator{ + fset: 
token.NewFileSet(), + meta: meta, + touched: make(map[string]struct{}, len(req.Touched)), + packages: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + } + for _, path := range req.Touched { + if !metadataMatchesFingerprint(path, meta, req.Local) { + return nil, unsupportedError{reason: "metadata fingerprint mismatch"} + } + validator.touched[path] = struct{}{} + } + out := make([]*packages.Package, 0, len(req.Touched)) + for _, path := range req.Touched { + pkg, err := validator.validatePackage(path) + if err != nil { + return nil, err + } + out = append(out, pkg) + } + return &TouchedValidationResult{ + Packages: out, + Backend: ModeCustom, + }, nil +} + +func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadResult, error) { + discoveryStart := time.Now() + meta, err := runGoList(ctx, goListRequest{ + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: req.Patterns, + NeedDeps: req.NeedDeps, + }) + if err != nil { + return nil, err + } + logTiming(ctx, "loader.custom.root.discovery", discoveryStart) + if len(meta) == 0 { + return nil, unsupportedError{reason: "empty go list result"} + } + pkgs := make(map[string]*packages.Package, len(meta)) + for path, m := range meta { + pkgs[path] = &packages.Package{ + ID: m.ImportPath, + Name: m.Name, + PkgPath: m.ImportPath, + GoFiles: append([]string(nil), metaFiles(m)...), + CompiledGoFiles: append([]string(nil), metaFiles(m)...), + Imports: make(map[string]*packages.Package), + } + if m.Error != nil && strings.TrimSpace(m.Error.Err) != "" { + pkgs[path].Errors = append(pkgs[path].Errors, packages.Error{ + Pos: "-", + Msg: m.Error.Err, + Kind: packages.ListError, + }) + } + } + for path, m := range meta { + pkg := pkgs[path] + for _, imp := range m.Imports { + target := imp + if mapped := m.ImportMap[imp]; mapped != "" { + target = mapped + } + if dep := pkgs[target]; dep != nil { + 
pkg.Imports[imp] = dep + } + } + } + roots := make([]*packages.Package, 0, len(req.Patterns)) + for _, m := range meta { + if m.DepOnly { + continue + } + if pkg := pkgs[m.ImportPath]; pkg != nil { + roots = append(roots, pkg) + } + } + if len(roots) == 0 { + return nil, unsupportedError{reason: "no root packages from metadata"} + } + sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) + return &RootLoadResult{ + Packages: roots, + Backend: ModeCustom, + Discovery: discoverySnapshotForMeta(meta, req.NeedDeps), + }, nil +} + +func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*LazyLoadResult, error) { + stopProfile, profileErr := startLoaderCPUProfile(req.Env) + if profileErr != nil { + return nil, profileErr + } + if stopProfile != nil { + defer stopProfile() + } + var ( + meta map[string]*packageMeta + err error + ) + if req.Discovery != nil && len(req.Discovery.meta) > 0 { + meta = req.Discovery.meta + } else { + meta, err = runGoList(ctx, goListRequest{ + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: []string{req.Package}, + NeedDeps: true, + }) + if err != nil { + return nil, err + } + } + if len(meta) == 0 { + return nil, unsupportedError{reason: "empty go list result"} + } + fset := req.Fset + if fset == nil { + fset = token.NewFileSet() + } + l := &customTypedGraphLoader{ + workspace: detectModuleRoot(req.WD), + ctx: ctx, + fset: fset, + meta: meta, + targets: map[string]struct{}{req.Package: {}}, + parseFile: req.ParseFile, + packages: make(map[string]*packages.Package, len(meta)), + typesPkgs: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + } + root, err := l.loadPackage(req.Package) + if err != nil { + return nil, err + } + logDuration(ctx, "loader.custom.lazy.read_files.cumulative", l.stats.read) + logDuration(ctx, "loader.custom.lazy.parse_files.cumulative", l.stats.parse) + 
logDuration(ctx, "loader.custom.lazy.typecheck.cumulative", l.stats.typecheck) + logDuration(ctx, "loader.custom.lazy.read_files.local.cumulative", l.stats.localRead) + logDuration(ctx, "loader.custom.lazy.read_files.external.cumulative", l.stats.externalRead) + logDuration(ctx, "loader.custom.lazy.parse_files.local.cumulative", l.stats.localParse) + logDuration(ctx, "loader.custom.lazy.parse_files.external.cumulative", l.stats.externalParse) + logDuration(ctx, "loader.custom.lazy.typecheck.local.cumulative", l.stats.localTypecheck) + logDuration(ctx, "loader.custom.lazy.typecheck.external.cumulative", l.stats.externalTypecheck) + return &LazyLoadResult{ + Packages: []*packages.Package{root}, + Backend: ModeCustom, + }, nil +} + +func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLoadResult, error) { + meta, err := runGoList(ctx, goListRequest{ + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: req.Patterns, + NeedDeps: true, + }) + if err != nil { + return nil, err + } + if len(meta) == 0 { + return nil, unsupportedError{reason: "empty go list result"} + } + fset := req.Fset + if fset == nil { + fset = token.NewFileSet() + } + targets := make(map[string]struct{}) + for _, m := range meta { + if m.DepOnly { + continue + } + targets[m.ImportPath] = struct{}{} + } + if len(targets) == 0 { + return nil, unsupportedError{reason: "no root packages from metadata"} + } + l := &customTypedGraphLoader{ + workspace: detectModuleRoot(req.WD), + ctx: ctx, + fset: fset, + meta: meta, + targets: targets, + parseFile: req.ParseFile, + packages: make(map[string]*packages.Package, len(meta)), + typesPkgs: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + } + roots := make([]*packages.Package, 0, len(targets)) + for _, m := range meta { + if m.DepOnly { + continue + } + root, err := l.loadPackage(m.ImportPath) + if err != nil { + return 
nil, err + } + roots = append(roots, root) + } + sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) + logDuration(ctx, "loader.custom.typed.read_files.cumulative", l.stats.read) + logDuration(ctx, "loader.custom.typed.parse_files.cumulative", l.stats.parse) + logDuration(ctx, "loader.custom.typed.typecheck.cumulative", l.stats.typecheck) + logDuration(ctx, "loader.custom.typed.read_files.local.cumulative", l.stats.localRead) + logDuration(ctx, "loader.custom.typed.read_files.external.cumulative", l.stats.externalRead) + logDuration(ctx, "loader.custom.typed.parse_files.local.cumulative", l.stats.localParse) + logDuration(ctx, "loader.custom.typed.parse_files.external.cumulative", l.stats.externalParse) + logDuration(ctx, "loader.custom.typed.typecheck.local.cumulative", l.stats.localTypecheck) + logDuration(ctx, "loader.custom.typed.typecheck.external.cumulative", l.stats.externalTypecheck) + return &PackageLoadResult{ + Packages: roots, + Backend: ModeCustom, + }, nil +} + +func (v *customValidator) validatePackage(path string) (*packages.Package, error) { + meta := v.meta[path] + if meta == nil { + return nil, unsupportedError{reason: "missing metadata for touched package"} + } + if v.loading[path] { + return nil, unsupportedError{reason: "touched package cycle"} + } + v.loading[path] = true + defer delete(v.loading, path) + pkg := &packages.Package{ + ID: meta.ImportPath, + Name: meta.Name, + PkgPath: meta.ImportPath, + Fset: v.fset, + GoFiles: append([]string(nil), metaFiles(meta)...), + CompiledGoFiles: append([]string(nil), metaFiles(meta)...), + Imports: make(map[string]*packages.Package), + ExportFile: meta.Export, + } + if meta.Error != nil && strings.TrimSpace(meta.Error.Err) != "" { + pkg.Errors = append(pkg.Errors, packages.Error{ + Pos: "-", + Msg: meta.Error.Err, + Kind: packages.ListError, + }) + return pkg, nil + } + files, errs := v.parseFiles(metaFiles(meta)) + pkg.Errors = append(pkg.Errors, errs...) 
+ if len(files) == 0 { + return pkg, nil + } + + tpkg := types.NewPackage(meta.ImportPath, meta.Name) + v.packages[meta.ImportPath] = tpkg + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + Implicits: make(map[ast.Node]types.Object), + Scopes: make(map[ast.Node]*types.Scope), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } + importer := importerFunc(func(importPath string) (*types.Package, error) { + if importPath == "unsafe" { + return types.Unsafe, nil + } + target := importPath + if mapped := meta.ImportMap[importPath]; mapped != "" { + target = mapped + } + if _, ok := v.touched[target]; ok { + if typed := v.packages[target]; typed != nil && typed.Complete() { + if depMeta := v.meta[target]; depMeta != nil { + pkg.Imports[importPath] = touchedPackageStub(v.fset, depMeta) + } + return typed, nil + } + checked, err := v.validatePackage(target) + if err != nil { + return nil, err + } + pkg.Imports[importPath] = checked + if len(checked.Errors) > 0 { + return nil, fmt.Errorf("touched dependency %s has errors", target) + } + if typed := v.packages[target]; typed != nil { + return typed, nil + } + return nil, unsupportedError{reason: "missing typed touched dependency"} + } + dep, err := v.importFromExport(target) + if err == nil { + if depMeta := v.meta[target]; depMeta != nil { + pkg.Imports[importPath] = touchedPackageStub(v.fset, depMeta) + } else { + pkg.Imports[importPath] = &packages.Package{PkgPath: target, Name: dep.Name()} + } + } + return dep, err + }) + var typeErrors []packages.Error + cfg := &types.Config{ + Importer: importer, + Sizes: types.SizesFor("gc", runtime.GOARCH), + Error: func(err error) { + typeErrors = append(typeErrors, toPackagesError(v.fset, err)) + }, + } + checker := types.NewChecker(cfg, v.fset, tpkg, info) + if err := checker.Files(files); err != nil && len(typeErrors) == 0 { + typeErrors = append(typeErrors, 
toPackagesError(v.fset, err)) + } + pkg.Syntax = files + pkg.Types = tpkg + pkg.TypesInfo = info + typeErrors = append(typeErrors, v.validateDeclaredImports(meta, files)...) + pkg.Errors = append(pkg.Errors, typeErrors...) + return pkg, nil +} + +func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, error) { + if pkg := l.packages[path]; pkg != nil && (pkg.Types != nil || len(pkg.Errors) > 0) { + return pkg, nil + } + meta := l.meta[path] + if meta == nil { + return nil, unsupportedError{reason: "missing lazy-load metadata"} + } + if l.loading[path] { + if pkg := l.packages[path]; pkg != nil { + return pkg, nil + } + return nil, unsupportedError{reason: "lazy-load cycle"} + } + l.loading[path] = true + defer delete(l.loading, path) + l.stats.packages++ + isLocal := isWorkspacePackage(l.workspace, meta.Dir) + if isLocal { + l.stats.localPackages++ + } else { + l.stats.externalPackages++ + } + + pkg := l.packages[path] + if pkg == nil { + pkg = &packages.Package{ + ID: meta.ImportPath, + Name: meta.Name, + PkgPath: meta.ImportPath, + Fset: l.fset, + GoFiles: append([]string(nil), metaFiles(meta)...), + CompiledGoFiles: append([]string(nil), metaFiles(meta)...), + Imports: make(map[string]*packages.Package), + ExportFile: meta.Export, + } + l.packages[path] = pkg + } + files, parseErrs := l.parseFiles(metaFiles(meta), isLocal) + pkg.Errors = append(pkg.Errors, parseErrs...) 
+ if len(files) == 0 { + if meta.Error != nil && strings.TrimSpace(meta.Error.Err) != "" { + pkg.Errors = append(pkg.Errors, packages.Error{ + Pos: "-", + Msg: meta.Error.Err, + Kind: packages.ListError, + }) + } + return pkg, nil + } + + tpkg := l.typesPkgs[path] + if tpkg == nil { + tpkg = types.NewPackage(meta.ImportPath, meta.Name) + l.typesPkgs[path] = tpkg + } + _, isTarget := l.targets[path] + needFullState := isTarget || isWorkspacePackage(l.workspace, meta.Dir) + var info *types.Info + if needFullState { + info = &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + Implicits: make(map[ast.Node]types.Object), + Scopes: make(map[ast.Node]*types.Scope), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } + } + var typeErrors []packages.Error + cfg := &types.Config{ + Sizes: types.SizesFor("gc", runtime.GOARCH), + IgnoreFuncBodies: !isWorkspacePackage(l.workspace, meta.Dir), + Importer: importerFunc(func(importPath string) (*types.Package, error) { + if importPath == "unsafe" { + return types.Unsafe, nil + } + target := importPath + if mapped := meta.ImportMap[importPath]; mapped != "" { + target = mapped + } + dep, err := l.loadPackage(target) + if err != nil { + return nil, err + } + pkg.Imports[importPath] = dep + if dep.Types != nil { + return dep.Types, nil + } + if typed := l.typesPkgs[target]; typed != nil { + return typed, nil + } + if len(dep.Errors) > 0 { + return nil, unsupportedError{reason: "lazy-load dependency has errors"} + } + return nil, unsupportedError{reason: "missing typed lazy-load dependency"} + }), + Error: func(err error) { + typeErrors = append(typeErrors, toPackagesError(l.fset, err)) + }, + } + checker := types.NewChecker(cfg, l.fset, tpkg, info) + typecheckStart := time.Now() + if err := checker.Files(files); err != nil && len(typeErrors) == 0 { + typeErrors = append(typeErrors, toPackagesError(l.fset, err)) + } + 
typecheckDuration := time.Since(typecheckStart) + l.stats.typecheck += typecheckDuration + if isLocal { + l.stats.localTypecheck += typecheckDuration + } else { + l.stats.externalTypecheck += typecheckDuration + } + if needFullState { + pkg.Syntax = files + } else { + pkg.Syntax = nil + } + pkg.Types = tpkg + pkg.TypesInfo = info + pkg.Errors = append(pkg.Errors, typeErrors...) + return pkg, nil +} + +func (v *customValidator) importFromExport(path string) (*types.Package, error) { + if typed := v.packages[path]; typed != nil && typed.Complete() { + return typed, nil + } + if v.importer != nil { + if imported, err := v.importer.Import(path); err == nil { + v.packages[path] = imported + return imported, nil + } + } + meta := v.meta[path] + if meta == nil { + return nil, unsupportedError{reason: "missing dependency metadata"} + } + if meta.Export == "" { + return v.loadDependencyFromSource(path) + } + exportPath := meta.Export + if !filepath.IsAbs(exportPath) { + exportPath = filepath.Join(meta.Dir, exportPath) + } + f, err := os.Open(exportPath) + if err != nil { + return nil, unsupportedError{reason: "open export data"} + } + defer f.Close() + r, err := gcexportdata.NewReader(f) + if err != nil { + return nil, unsupportedError{reason: "read export data"} + } + view := make(map[string]*types.Package, len(v.packages)) + for pkgPath, pkg := range v.packages { + view[pkgPath] = pkg + } + tpkg, err := gcexportdata.Read(r, v.fset, view, path) + if err != nil { + return v.loadDependencyFromSource(path) + } + v.packages[path] = tpkg + return tpkg, nil +} + +func (v *customValidator) loadDependencyFromSource(path string) (*types.Package, error) { + if typed := v.packages[path]; typed != nil && typed.Complete() { + return typed, nil + } + meta := v.meta[path] + if meta == nil { + return nil, unsupportedError{reason: "missing source dependency metadata"} + } + if v.loading[path] { + if typed := v.packages[path]; typed != nil { + return typed, nil + } + return nil, 
unsupportedError{reason: "dependency cycle"} + } + v.loading[path] = true + defer delete(v.loading, path) + + tpkg := v.packages[path] + if tpkg == nil { + tpkg = types.NewPackage(meta.ImportPath, meta.Name) + v.packages[path] = tpkg + } + files, errs := v.parseFiles(metaFiles(meta)) + if len(errs) > 0 { + return nil, unsupportedError{reason: "dependency parse error"} + } + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + Implicits: make(map[ast.Node]types.Object), + Scopes: make(map[ast.Node]*types.Scope), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } + cfg := &types.Config{ + Importer: importerFunc(func(importPath string) (*types.Package, error) { + if importPath == "unsafe" { + return types.Unsafe, nil + } + target := importPath + if mapped := meta.ImportMap[importPath]; mapped != "" { + target = mapped + } + if _, ok := v.touched[target]; ok { + checked, err := v.validatePackage(target) + if err != nil { + return nil, err + } + if len(checked.Errors) > 0 { + return nil, unsupportedError{reason: "touched dependency has validation errors"} + } + return v.packages[target], nil + } + return v.importFromExport(target) + }), + Sizes: types.SizesFor("gc", runtime.GOARCH), + IgnoreFuncBodies: true, + } + if err := types.NewChecker(cfg, v.fset, tpkg, info).Files(files); err != nil { + return nil, unsupportedError{reason: "dependency typecheck error"} + } + return tpkg, nil +} + +func (v *customValidator) parseFiles(names []string) ([]*ast.File, []packages.Error) { + files := make([]*ast.File, 0, len(names)) + var errs []packages.Error + for _, name := range names { + src, err := os.ReadFile(name) + if err != nil { + errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) + continue + } + f, err := parser.ParseFile(v.fset, name, src, parser.AllErrors|parser.ParseComments) + if err != nil { + switch typed 
:= err.(type) { + case scanner.ErrorList: + for _, parseErr := range typed { + errs = append(errs, packages.Error{ + Pos: parseErr.Pos.String(), + Msg: parseErr.Msg, + Kind: packages.ParseError, + }) + } + default: + errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) + } + } + if f != nil { + files = append(files, f) + } + } + return files, errs +} + +func (l *customTypedGraphLoader) parseFiles(names []string, isLocal bool) ([]*ast.File, []packages.Error) { + files := make([]*ast.File, 0, len(names)) + var errs []packages.Error + for _, name := range names { + readStart := time.Now() + src, err := os.ReadFile(name) + readDuration := time.Since(readStart) + l.stats.read += readDuration + l.stats.filesRead++ + if isLocal { + l.stats.localRead += readDuration + l.stats.localFilesRead++ + } else { + l.stats.externalRead += readDuration + l.stats.externalFilesRead++ + } + if err != nil { + errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) + continue + } + var f *ast.File + parseStart := time.Now() + if l.parseFile != nil { + f, err = l.parseFile(l.fset, name, src) + } else { + f, err = parser.ParseFile(l.fset, name, src, parser.AllErrors|parser.ParseComments) + } + parseDuration := time.Since(parseStart) + l.stats.parse += parseDuration + if isLocal { + l.stats.localParse += parseDuration + } else { + l.stats.externalParse += parseDuration + } + if err != nil { + switch typed := err.(type) { + case scanner.ErrorList: + for _, parseErr := range typed { + errs = append(errs, packages.Error{ + Pos: parseErr.Pos.String(), + Msg: parseErr.Msg, + Kind: packages.ParseError, + }) + } + default: + errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) + } + } + if f != nil { + files = append(files, f) + } + } + return files, errs +} + +func toPackagesError(fset *token.FileSet, err error) packages.Error { + switch typed := err.(type) { + 
case packages.Error: + return typed + case types.Error: + return packages.Error{ + Pos: typed.Fset.Position(typed.Pos).String(), + Msg: typed.Msg, + Kind: packages.TypeError, + } + default: + pos := "-" + if fset != nil { + if te, ok := err.(interface{ Pos() token.Pos }); ok { + pos = fset.Position(te.Pos()).String() + } + } + return packages.Error{Pos: pos, Msg: err.Error(), Kind: packages.UnknownError} + } +} + +type importerFunc func(path string) (*types.Package, error) + +func (f importerFunc) Import(path string) (*types.Package, error) { return f(path) } + +func (v *customValidator) validateDeclaredImports(meta *packageMeta, files []*ast.File) []packages.Error { + var errs []packages.Error + for _, file := range files { + used := usedImportsInFile(file) + for _, spec := range file.Imports { + if spec == nil || spec.Path == nil { + continue + } + path := strings.Trim(spec.Path.Value, "\"") + if path == "" { + continue + } + target := path + if mapped := meta.ImportMap[path]; mapped != "" { + target = mapped + } + name := importName(spec) + if name != "_" && name != "." 
{ + if _, ok := used[name]; !ok { + errs = append(errs, packages.Error{ + Pos: v.fset.Position(spec.Pos()).String(), + Msg: fmt.Sprintf("%q imported and not used", path), + Kind: packages.TypeError, + }) + continue + } + } + if _, err := v.importFromExport(target); err != nil { + errs = append(errs, packages.Error{ + Pos: v.fset.Position(spec.Pos()).String(), + Msg: fmt.Sprintf("could not import %s", path), + Kind: packages.TypeError, + }) + } + } + } + return errs +} + +func usedImportsInFile(file *ast.File) map[string]struct{} { + used := make(map[string]struct{}) + ast.Inspect(file, func(node ast.Node) bool { + sel, ok := node.(*ast.SelectorExpr) + if !ok { + return true + } + ident, ok := sel.X.(*ast.Ident) + if !ok || ident.Name == "" { + return true + } + used[ident.Name] = struct{}{} + return true + }) + return used +} + +func importName(spec *ast.ImportSpec) string { + if spec == nil || spec.Path == nil { + return "" + } + if spec.Name != nil && spec.Name.Name != "" { + return spec.Name.Name + } + path := strings.Trim(spec.Path.Value, "\"") + if path == "" { + return "" + } + if slash := strings.LastIndex(path, "/"); slash >= 0 { + path = path[slash+1:] + } + return path +} + +func discoverTouchedMetadata(ctx context.Context, req TouchedValidationRequest) (map[string]*packageMeta, error) { + metas, err := runGoList(ctx, goListRequest{ + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: req.Touched, + NeedDeps: true, + }) + if err != nil { + return nil, err + } + if len(metas) == 0 { + return nil, unsupportedError{reason: "empty go list result"} + } + for _, touched := range req.Touched { + if _, ok := metas[touched]; !ok { + return nil, unsupportedError{reason: "missing touched package in metadata"} + } + } + return metas, nil +} + +func normalizeImports(imports []string, importMap map[string]string) []string { + if len(imports) == 0 { + return nil + } + out := make([]string, 0, len(imports)) + for _, imp := range imports { + if mapped := 
importMap[imp]; mapped != "" { + out = append(out, mapped) + continue + } + out = append(out, imp) + } + sort.Strings(out) + return out +} + +func metaFiles(meta *packageMeta) []string { + if meta == nil { + return nil + } + if len(meta.CompiledGoFiles) > 0 { + return meta.CompiledGoFiles + } + return meta.GoFiles +} + +func discoverySnapshotForMeta(meta map[string]*packageMeta, complete bool) *DiscoverySnapshot { + if !complete || len(meta) == 0 { + return nil + } + return &DiscoverySnapshot{meta: meta} +} + +func isWorkspacePackage(workspaceRoot, dir string) bool { + if workspaceRoot == "" || dir == "" { + return false + } + workspaceRoot = canonicalLoaderPath(workspaceRoot) + dir = canonicalLoaderPath(dir) + if dir == workspaceRoot { + return true + } + rel, err := filepath.Rel(workspaceRoot, dir) + if err != nil { + return false + } + return rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator)) +} + +func detectModuleRoot(start string) string { + start = canonicalLoaderPath(start) + for dir := start; dir != "" && dir != "." 
&& dir != string(filepath.Separator); dir = filepath.Dir(dir) { + if info, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !info.IsDir() { + return dir + } + next := filepath.Dir(dir) + if next == dir { + break + } + } + return start +} + +func canonicalLoaderPath(path string) string { + path = filepath.Clean(path) + if resolved, err := filepath.EvalSymlinks(path); err == nil && resolved != "" { + return filepath.Clean(resolved) + } + return path +} + +func startLoaderCPUProfile(env []string) (func(), error) { + path := envValue(env, "WIRE_LOADER_CPU_PROFILE") + if strings.TrimSpace(path) == "" { + return nil, nil + } + f, err := os.Create(path) + if err != nil { + return nil, err + } + if err := pprof.StartCPUProfile(f); err != nil { + _ = f.Close() + return nil, err + } + return func() { + pprof.StopCPUProfile() + _ = f.Close() + }, nil +} + +func envValue(env []string, key string) string { + for i := len(env) - 1; i >= 0; i-- { + name, value, ok := strings.Cut(env[i], "=") + if ok && name == key { + return value + } + } + return "" +} + + +func touchedPackageStub(fset *token.FileSet, meta *packageMeta) *packages.Package { + if meta == nil { + return nil + } + return &packages.Package{ + ID: meta.ImportPath, + Name: meta.Name, + PkgPath: meta.ImportPath, + Fset: fset, + GoFiles: append([]string(nil), metaFiles(meta)...), + CompiledGoFiles: append([]string(nil), metaFiles(meta)...), + Imports: make(map[string]*packages.Package), + ExportFile: meta.Export, + } +} + +func metadataMatchesFingerprint(pkgPath string, meta map[string]*packageMeta, local []LocalPackageFingerprint) bool { + for _, fp := range local { + if fp.PkgPath != pkgPath { + continue + } + pm := meta[pkgPath] + if pm == nil { + return false + } + want := append([]string(nil), fp.Files...) + got := append([]string(nil), metaFiles(pm)...) 
+ sort.Strings(want) + sort.Strings(got) + if len(want) != len(got) { + return false + } + for i := range want { + if filepath.Clean(want[i]) != filepath.Clean(got[i]) { + return false + } + } + return true + } + return true +} diff --git a/internal/loader/discovery.go b/internal/loader/discovery.go new file mode 100644 index 0000000..a6aba46 --- /dev/null +++ b/internal/loader/discovery.go @@ -0,0 +1,94 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" +) + +type goListRequest struct { + WD string + Env []string + Tags string + Patterns []string + NeedDeps bool +} + +func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, error) { + args := []string{"list", "-json", "-e", "-compiled", "-export"} + if req.NeedDeps { + args = append(args, "-deps") + } + if req.Tags != "" { + args = append(args, "-tags=wireinject "+req.Tags) + } else { + args = append(args, "-tags=wireinject") + } + args = append(args, "--") + args = append(args, req.Patterns...) + + cmd := exec.CommandContext(ctx, "go", args...) 
+ cmd.Dir = req.WD + if len(req.Env) > 0 { + cmd.Env = req.Env + } else { + cmd.Env = os.Environ() + } + var stdout bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + if err := cmd.Run(); err != nil { + return nil, fmt.Errorf("go list: %w: %s", err, stderr.String()) + } + dec := json.NewDecoder(&stdout) + out := make(map[string]*packageMeta) + for { + var meta packageMeta + if err := dec.Decode(&meta); err != nil { + if err == io.EOF { + break + } + return nil, err + } + if meta.ImportPath == "" { + continue + } + for i, name := range meta.GoFiles { + if !filepath.IsAbs(name) { + meta.GoFiles[i] = filepath.Join(meta.Dir, name) + } + } + for i, name := range meta.CompiledGoFiles { + if !filepath.IsAbs(name) { + meta.CompiledGoFiles[i] = filepath.Join(meta.Dir, name) + } + } + if meta.Export != "" && !filepath.IsAbs(meta.Export) { + meta.Export = filepath.Join(meta.Dir, meta.Export) + } + meta.Imports = normalizeImports(meta.Imports, meta.ImportMap) + copyMeta := meta + out[meta.ImportPath] = &copyMeta + } + return out, nil +} diff --git a/internal/loader/fallback.go b/internal/loader/fallback.go new file mode 100644 index 0000000..513694c --- /dev/null +++ b/internal/loader/fallback.go @@ -0,0 +1,224 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +package loader + +import ( + "context" + "errors" + "go/token" + + "golang.org/x/tools/go/packages" +) + +type defaultLoader struct{} + +func (defaultLoader) LoadPackages(ctx context.Context, req PackageLoadRequest) (*PackageLoadResult, error) { + var unsupported unsupportedError + if req.LoaderMode != ModeFallback { + result, err := loadPackagesCustom(ctx, req) + if err == nil { + return result, nil + } + if !errors.As(err, &unsupported) { + return nil, err + } + } + result := &PackageLoadResult{ + Backend: ModeFallback, + } + switch req.LoaderMode { + case ModeFallback: + result.FallbackReason = FallbackReasonForcedFallback + default: + result.FallbackReason = FallbackReasonCustomUnsupported + if unsupported.reason != "" { + result.FallbackDetail = unsupported.reason + } + } + cfg := &packages.Config{ + Context: ctx, + Mode: req.Mode, + Dir: req.WD, + Env: req.Env, + BuildFlags: []string{"-tags=wireinject"}, + Fset: req.Fset, + } + if cfg.Fset == nil { + cfg.Fset = token.NewFileSet() + } + if req.ParseFile != nil { + cfg.ParseFile = req.ParseFile + } + if req.Tags != "" { + cfg.BuildFlags[0] += " " + req.Tags + } + escaped := make([]string, len(req.Patterns)) + for i := range req.Patterns { + escaped[i] = "pattern=" + req.Patterns[i] + } + pkgs, err := packages.Load(cfg, escaped...) 
+ if err != nil { + return nil, err + } + result.Packages = pkgs + return result, nil +} + +func (defaultLoader) LoadRootGraph(ctx context.Context, req RootLoadRequest) (*RootLoadResult, error) { + var unsupported unsupportedError + if req.Mode != ModeFallback { + result, err := loadRootGraphCustom(ctx, req) + if err == nil { + return result, nil + } + if !errors.As(err, &unsupported) { + return nil, err + } + } + result := &RootLoadResult{ + Backend: ModeFallback, + } + switch req.Mode { + case ModeFallback: + result.FallbackReason = FallbackReasonForcedFallback + default: + result.FallbackReason = FallbackReasonCustomUnsupported + if unsupported.reason != "" { + result.FallbackDetail = unsupported.reason + } + } + cfg := &packages.Config{ + Context: ctx, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports, + Dir: req.WD, + Env: req.Env, + BuildFlags: []string{"-tags=wireinject"}, + Fset: req.Fset, + } + if req.NeedDeps { + cfg.Mode |= packages.NeedDeps + } + if req.Fset == nil { + cfg.Fset = token.NewFileSet() + } + if req.Tags != "" { + cfg.BuildFlags[0] += " " + req.Tags + } + escaped := make([]string, len(req.Patterns)) + for i := range req.Patterns { + escaped[i] = "pattern=" + req.Patterns[i] + } + pkgs, err := packages.Load(cfg, escaped...) 
+ if err != nil { + return nil, err + } + result.Packages = pkgs + return result, nil +} + +func (defaultLoader) LoadTypedPackageGraph(ctx context.Context, req LazyLoadRequest) (*LazyLoadResult, error) { + var unsupported unsupportedError + if req.LoaderMode != ModeFallback { + result, err := loadTypedPackageGraphCustom(ctx, req) + if err == nil { + return result, nil + } + if !errors.As(err, &unsupported) { + return nil, err + } + } + result := &LazyLoadResult{ + Backend: ModeFallback, + } + switch req.LoaderMode { + case ModeFallback: + result.FallbackReason = FallbackReasonForcedFallback + default: + result.FallbackReason = FallbackReasonCustomUnsupported + if unsupported.reason != "" { + result.FallbackDetail = unsupported.reason + } + } + cfg := &packages.Config{ + Context: ctx, + Mode: req.Mode, + Dir: req.WD, + Env: req.Env, + BuildFlags: []string{"-tags=wireinject"}, + Fset: req.Fset, + } + if cfg.Fset == nil { + cfg.Fset = token.NewFileSet() + } + if req.ParseFile != nil { + cfg.ParseFile = req.ParseFile + } + if req.Tags != "" { + cfg.BuildFlags[0] += " " + req.Tags + } + pkgs, err := packages.Load(cfg, "pattern="+req.Package) + if err != nil { + return nil, err + } + result.Packages = pkgs + return result, nil +} + +func (defaultLoader) ValidateTouchedPackages(ctx context.Context, req TouchedValidationRequest) (*TouchedValidationResult, error) { + var unsupported unsupportedError + if req.Mode != ModeFallback { + result, err := validateTouchedPackagesCustom(ctx, req) + if err == nil { + return result, nil + } + if !errors.As(err, &unsupported) { + return nil, err + } + } + return validateTouchedPackagesFallback(ctx, req, unsupported.reason) +} + +func validateTouchedPackagesFallback(ctx context.Context, req TouchedValidationRequest, detail string) (*TouchedValidationResult, error) { + result := &TouchedValidationResult{ + Backend: ModeFallback, + } + switch req.Mode { + case ModeFallback: + result.FallbackReason = FallbackReasonForcedFallback + default: 
+ result.FallbackReason = FallbackReasonCustomUnsupported + result.FallbackDetail = detail + } + if len(req.Touched) == 0 { + return result, nil + } + cfg := &packages.Config{ + Context: ctx, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedExportsFile | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesSizes, + Dir: req.WD, + Env: req.Env, + BuildFlags: []string{"-tags=wireinject"}, + Fset: token.NewFileSet(), + } + if req.Tags != "" { + cfg.BuildFlags[0] += " " + req.Tags + } + pkgs, err := packages.Load(cfg, req.Touched...) + if err != nil { + return nil, err + } + result.Packages = pkgs + return result, nil +} diff --git a/internal/loader/loader.go b/internal/loader/loader.go new file mode 100644 index 0000000..e26747b --- /dev/null +++ b/internal/loader/loader.go @@ -0,0 +1,135 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package loader + +import ( + "context" + "go/ast" + "go/token" + + "golang.org/x/tools/go/packages" +) + +type Mode string + +const ( + ModeAuto Mode = "auto" + ModeCustom Mode = "custom" + ModeFallback Mode = "fallback" +) + +type FallbackReason string + +const ( + FallbackReasonNone FallbackReason = "" + FallbackReasonForcedFallback FallbackReason = "forced_fallback" + FallbackReasonCustomNotImplemented FallbackReason = "custom_not_implemented" + FallbackReasonCustomUnsupported FallbackReason = "custom_unsupported" +) + +type LocalPackageFingerprint struct { + PkgPath string + ContentHash string + ShapeHash string + Files []string +} + +type DiscoverySnapshot struct { + meta map[string]*packageMeta +} + +type TouchedValidationRequest struct { + WD string + Env []string + Tags string + Touched []string + Local []LocalPackageFingerprint + Mode Mode +} + +type TouchedValidationResult struct { + Packages []*packages.Package + Backend Mode + FallbackReason FallbackReason + FallbackDetail string +} + +type RootLoadRequest struct { + WD string + Env []string + Tags string + Patterns []string + NeedDeps bool + Mode Mode + Fset *token.FileSet +} + +type RootLoadResult struct { + Packages []*packages.Package + Backend Mode + FallbackReason FallbackReason + FallbackDetail string + Discovery *DiscoverySnapshot +} + +type PackageLoadRequest struct { + WD string + Env []string + Tags string + Patterns []string + Mode packages.LoadMode + LoaderMode Mode + Fset *token.FileSet + ParseFile ParseFileFunc +} + +type PackageLoadResult struct { + Packages []*packages.Package + Backend Mode + FallbackReason FallbackReason + FallbackDetail string +} + +type ParseFileFunc func(*token.FileSet, string, []byte) (*ast.File, error) + +type LazyLoadRequest struct { + WD string + Env []string + Tags string + Package string + Mode packages.LoadMode + LoaderMode Mode + Fset *token.FileSet + ParseFile ParseFileFunc + Discovery *DiscoverySnapshot +} + +type LazyLoadResult struct { + Packages 
[]*packages.Package + Backend Mode + FallbackReason FallbackReason + FallbackDetail string +} + +type Loader interface { + LoadPackages(context.Context, PackageLoadRequest) (*PackageLoadResult, error) + LoadRootGraph(context.Context, RootLoadRequest) (*RootLoadResult, error) + LoadTypedPackageGraph(context.Context, LazyLoadRequest) (*LazyLoadResult, error) + ValidateTouchedPackages(context.Context, TouchedValidationRequest) (*TouchedValidationResult, error) +} + +func New() Loader { + return defaultLoader{} +} diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go new file mode 100644 index 0000000..0e38d99 --- /dev/null +++ b/internal/loader/loader_test.go @@ -0,0 +1,821 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import ( + "context" + "go/ast" + "go/parser" + "go/token" + "os" + "path/filepath" + "sort" + "strings" + "testing" + + "golang.org/x/tools/go/packages" +) + +func TestModeFromEnvDefaultAuto(t *testing.T) { + if got := ModeFromEnv(nil); got != ModeAuto { + t.Fatalf("ModeFromEnv(nil) = %q, want %q", got, ModeAuto) + } +} + +func TestModeFromEnvUsesLastMatchingValue(t *testing.T) { + env := []string{ + "WIRE_LOADER_MODE=fallback", + "OTHER=value", + "WIRE_LOADER_MODE=custom", + } + if got := ModeFromEnv(env); got != ModeCustom { + t.Fatalf("ModeFromEnv(...) 
= %q, want %q", got, ModeCustom) + } +} + +func TestModeFromEnvIgnoresInvalidValues(t *testing.T) { + env := []string{ + "WIRE_LOADER_MODE=invalid", + } + if got := ModeFromEnv(env); got != ModeAuto { + t.Fatalf("ModeFromEnv(...) = %q, want %q", got, ModeAuto) + } +} + +func TestFallbackLoaderReasonFromMode(t *testing.T) { + l := New() + + gotAuto, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: ".", + Env: []string{}, + Touched: []string{}, + Mode: ModeAuto, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(auto) error = %v", err) + } + if gotAuto.Backend != ModeCustom { + t.Fatalf("auto backend = %q, want %q", gotAuto.Backend, ModeCustom) + } + if gotAuto.FallbackReason != FallbackReasonNone { + t.Fatalf("auto fallback reason = %q, want none", gotAuto.FallbackReason) + } + + gotForced, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: ".", + Env: []string{}, + Touched: []string{}, + Mode: ModeFallback, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(fallback) error = %v", err) + } + if gotForced.FallbackReason != FallbackReasonForcedFallback { + t.Fatalf("forced fallback reason = %q, want %q", gotForced.FallbackReason, FallbackReasonForcedFallback) + } +} + +func TestCustomTouchedValidationSuccess(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "a", "a.go"), "package a\n\nimport \"fmt\"\n\nfunc Use() string { return fmt.Sprint(\"ok\") }\n") + + l := New() + got, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Mode: ModeCustom, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(custom) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } + if got.FallbackReason != 
FallbackReasonNone { + t.Fatalf("fallback reason = %q, want none", got.FallbackReason) + } + if len(got.Packages) != 1 { + t.Fatalf("packages len = %d, want 1", len(got.Packages)) + } + if len(got.Packages[0].Errors) != 0 { + t.Fatalf("unexpected package errors: %+v", got.Packages[0].Errors) + } +} + +func TestValidateTouchedPackagesAutoUsesCustomWhenSupported(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "a", "a.go"), "package a\n\nfunc Use() string { return \"ok\" }\n") + + l := New() + got, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Mode: ModeAuto, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(auto) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } + if got.FallbackReason != FallbackReasonNone { + t.Fatalf("fallback reason = %q, want none", got.FallbackReason) + } +} + +func TestCustomTouchedValidationTypeError(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "a", "a.go"), "package a\n\nfunc Broken() int { return missing }\n") + + l := New() + got, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Mode: ModeCustom, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(custom) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } + if len(got.Packages) != 1 { + t.Fatalf("packages len = %d, want 1", len(got.Packages)) + } + if len(got.Packages[0].Errors) == 0 { + t.Fatal("expected type-check errors") + } +} + +func 
TestValidateTouchedPackagesCustomMatchesFallback(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\ntype T struct{}\nfunc New() *T { return &T{} }\n") + writeTestFile(t, filepath.Join(root, "app", "app.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Use() *dep.T { return dep.New() }\n") + + l := New() + custom, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/app"}, + Mode: ModeCustom, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(custom) error = %v", err) + } + fallback, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/app"}, + Mode: ModeFallback, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(fallback) error = %v", err) + } + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, false) +} + +func TestValidateTouchedPackagesCustomMatchesFallbackTypeErrors(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "a", "a.go"), "package a\n\nfunc Broken() int { return missing }\n") + + l := New() + custom, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Mode: ModeCustom, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(custom) error = %v", err) + } + fallback, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Mode: ModeFallback, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(fallback) error = %v", err) + } + 
compareRootPackagesOnly(t, custom.Packages, fallback.Packages, false) +} + +func TestValidateTouchedPackagesAutoReportsFallbackDetail(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "a", "a.go"), "package a\n\nfunc Use() string { return \"ok\" }\n") + + l := New() + got, err := l.ValidateTouchedPackages(context.Background(), TouchedValidationRequest{ + WD: root, + Env: os.Environ(), + Touched: []string{"example.com/app/a"}, + Local: []LocalPackageFingerprint{ + { + PkgPath: "example.com/app/a", + ContentHash: "wrong", + ShapeHash: "wrong", + Files: []string{filepath.Join(root, "a", "a.go")}, + }, + }, + Mode: ModeAuto, + }) + if err != nil { + t.Fatalf("ValidateTouchedPackages(auto) error = %v", err) + } + if got.Backend != ModeFallback { + t.Fatalf("backend = %q, want %q", got.Backend, ModeFallback) + } + if got.FallbackReason != FallbackReasonCustomUnsupported { + t.Fatalf("fallback reason = %q, want %q", got.FallbackReason, FallbackReasonCustomUnsupported) + } + if got.FallbackDetail != "metadata fingerprint mismatch" { + t.Fatalf("fallback detail = %q, want %q", got.FallbackDetail, "metadata fingerprint mismatch") + } +} + +func TestLoadRootGraphFallback(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "app.go"), "package app\n\nimport _ \"fmt\"\n") + + l := New() + got, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: root, + Env: os.Environ(), + Patterns: []string{"./app"}, + NeedDeps: true, + Mode: ModeFallback, + }) + if err != nil { + t.Fatalf("LoadRootGraph error = %v", err) + } + if got.Backend != ModeFallback { + t.Fatalf("backend = %q, want %q", got.Backend, ModeFallback) + } + if got.FallbackReason != FallbackReasonForcedFallback { + t.Fatalf("fallback reason = %q, want %q", got.FallbackReason, 
FallbackReasonForcedFallback) + } + if len(got.Packages) == 0 { + t.Fatal("expected loaded root packages") + } +} + +func TestLoadRootGraphCustom(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "app.go"), "package app\n\nimport _ \"example.com/app/dep\"\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n") + + l := New() + got, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: root, + Env: os.Environ(), + Patterns: []string{"./app"}, + NeedDeps: true, + Mode: ModeCustom, + }) + if err != nil { + t.Fatalf("LoadRootGraph(custom) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } + if got.FallbackReason != FallbackReasonNone { + t.Fatalf("fallback reason = %q, want none", got.FallbackReason) + } + if len(got.Packages) != 1 { + t.Fatalf("packages len = %d, want 1", len(got.Packages)) + } + if got.Packages[0].Imports["example.com/app/dep"] == nil { + t.Fatal("expected custom root graph to wire local import dependency") + } +} + +func TestLoadRootGraphAutoUsesCustomWhenSupported(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "app.go"), "package app\n\nimport _ \"example.com/app/dep\"\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n") + + l := New() + got, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: root, + Env: os.Environ(), + Patterns: []string{"./app"}, + NeedDeps: true, + Mode: ModeAuto, + Fset: token.NewFileSet(), + }) + if err != nil { + t.Fatalf("LoadRootGraph(auto) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } +} + +func TestMetaFilesFallsBackToGoFiles(t *testing.T) { + meta := 
&packageMeta{ + GoFiles: []string{"a.go", "b.go"}, + } + got := metaFiles(meta) + if len(got) != 2 || got[0] != "a.go" || got[1] != "b.go" { + t.Fatalf("metaFiles(go-only) = %v, want GoFiles fallback", got) + } + + meta.CompiledGoFiles = []string{"c.go"} + got = metaFiles(meta) + if len(got) != 1 || got[0] != "c.go" { + t.Fatalf("metaFiles(compiled) = %v, want CompiledGoFiles", got) + } +} + +func TestLoadTypedPackageGraphFallback(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "app.go"), "package app\n\nfunc Value() int { return 42 }\n") + + var parseCalls int + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeFallback, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph error = %v", err) + } + if got.Backend != ModeFallback { + t.Fatalf("backend = %q, want %q", got.Backend, ModeFallback) + } + if got.FallbackReason != FallbackReasonForcedFallback { + t.Fatalf("fallback reason = %q, want %q", got.FallbackReason, FallbackReasonForcedFallback) + } + if len(got.Packages) != 1 { + t.Fatalf("packages len = %d, want 1", len(got.Packages)) + } + if parseCalls == 0 { + t.Fatal("expected ParseFile hook to be used") + } +} + +func TestLoadTypedPackageGraphCustom(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + 
writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\ntype T struct{}\nfunc New() *T { return &T{} }\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() *dep.T { return dep.New() }\n") + + var parseCalls int + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } + if got.FallbackReason != FallbackReasonNone { + t.Fatalf("fallback reason = %q, want none", got.FallbackReason) + } + if len(got.Packages) != 1 { + t.Fatalf("packages len = %d, want 1", len(got.Packages)) + } + rootPkg := got.Packages[0] + if rootPkg.Types == nil || rootPkg.TypesInfo == nil || len(rootPkg.Syntax) == 0 { + t.Fatalf("root package missing typed syntax: %+v", rootPkg) + } + depPkg := rootPkg.Imports["example.com/app/dep"] + if depPkg == nil || depPkg.Types == nil || len(depPkg.Syntax) == 0 { + t.Fatalf("dep package missing typed syntax: %+v", depPkg) + } + if parseCalls < 2 { + t.Fatalf("parseCalls = %d, want at least 2", parseCalls) + } +} + +func TestLoadTypedPackageGraphAutoUsesCustomWhenSupported(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, 
filepath.Join(root, "dep", "dep.go"), "package dep\n\ntype T struct{}\nfunc New() *T { return &T{} }\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() *dep.T { return dep.New() }\n") + + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeAuto, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(auto) error = %v", err) + } + if got.Backend != ModeCustom { + t.Fatalf("backend = %q, want %q", got.Backend, ModeCustom) + } +} + +func TestLoadTypedPackageGraphCustomKeepsExternalPackagesLight(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"fmt\"\n\nfunc Init() string { return fmt.Sprint(\"ok\") }\n") + + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, 
parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + rootPkg := got.Packages[0] + fmtPkg := rootPkg.Imports["fmt"] + if fmtPkg == nil { + t.Fatal("expected fmt import package") + } + if fmtPkg.Types == nil { + t.Fatalf("fmt package missing types: %+v", fmtPkg) + } + if fmtPkg.TypesInfo != nil { + t.Fatalf("fmt package TypesInfo should be nil, got %+v", fmtPkg.TypesInfo) + } + if len(fmtPkg.Syntax) != 0 { + t.Fatalf("fmt package Syntax len = %d, want 0", len(fmtPkg.Syntax)) + } +} + +func TestLoadRootGraphCustomMatchesFallback(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport _ \"example.com/app/dep\"\n") + + l := New() + custom, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: root, + Env: os.Environ(), + Patterns: []string{"./app"}, + NeedDeps: true, + Mode: ModeCustom, + Fset: token.NewFileSet(), + }) + if err != nil { + t.Fatalf("LoadRootGraph(custom) error = %v", err) + } + fallback, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: root, + Env: os.Environ(), + Patterns: []string{"./app"}, + NeedDeps: true, + Mode: ModeFallback, + Fset: token.NewFileSet(), + }) + if err != nil { + t.Fatalf("LoadRootGraph(fallback) error = %v", err) + } + comparePackageGraphs(t, custom.Packages, fallback.Packages, false) +} + +func TestLoadTypedPackageGraphCustomMatchesFallback(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\ntype T struct{}\nfunc New() *T { return &T{} }\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport 
\"example.com/app/dep\"\n\nfunc Init() *dep.T { return dep.New() }\n") + + l := New() + custom, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + fallback, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeFallback, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(fallback) error = %v", err) + } + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomMatchesFallbackTypeErrors(t *testing.T) { + root := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nfunc Broken() int { return missing }\n") + + l := New() + custom, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: 
"example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + fallback, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: os.Environ(), + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeFallback, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(fallback) error = %v", err) + } + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func comparePackageGraphs(t *testing.T, got []*packages.Package, want []*packages.Package, requireTyped bool) { + t.Helper() + gotAll := collectGraph(got) + wantAll := collectGraph(want) + if len(gotAll) != len(wantAll) { + t.Fatalf("package graph size = %d, want %d", len(gotAll), len(wantAll)) + } + for path, wantPkg := range wantAll { + gotPkg := gotAll[path] + if gotPkg == nil { + t.Fatalf("missing package %q in custom graph", path) + } + if gotPkg.Name != wantPkg.Name { + t.Fatalf("package %q name = %q, want %q", path, gotPkg.Name, wantPkg.Name) + } + if !equalStrings(gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) { + t.Fatalf("package %q 
compiled files = %v, want %v", path, gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) + } + if !equalImportPaths(gotPkg.Imports, wantPkg.Imports) { + t.Fatalf("package %q imports = %v, want %v", path, sortedImportPaths(gotPkg.Imports), sortedImportPaths(wantPkg.Imports)) + } + gotErrs := comparableErrors(gotPkg.Errors) + wantErrs := comparableErrors(wantPkg.Errors) + if len(gotErrs) != len(wantErrs) { + t.Fatalf("package %q comparable errors len = %d, want %d; got=%v want=%v", path, len(gotErrs), len(wantErrs), gotErrs, wantErrs) + } + for i := range gotErrs { + if gotErrs[i] != wantErrs[i] { + t.Fatalf("package %q comparable error[%d] = %q, want %q", path, i, gotErrs[i], wantErrs[i]) + } + } + if requireTyped { + gotTyped := gotPkg.Types != nil && gotPkg.TypesInfo != nil && len(gotPkg.Syntax) > 0 + wantTyped := wantPkg.Types != nil && wantPkg.TypesInfo != nil && len(wantPkg.Syntax) > 0 + if gotTyped != wantTyped { + t.Fatalf("package %q typed state = %v, want %v", path, gotTyped, wantTyped) + } + } + } +} + +func compareRootPackagesOnly(t *testing.T, got []*packages.Package, want []*packages.Package, requireTyped bool) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("root package count = %d, want %d", len(got), len(want)) + } + gotByPath := make(map[string]*packages.Package, len(got)) + for _, pkg := range got { + gotByPath[pkg.PkgPath] = pkg + } + for _, wantPkg := range want { + gotPkg := gotByPath[wantPkg.PkgPath] + if gotPkg == nil { + t.Fatalf("missing root package %q", wantPkg.PkgPath) + } + if gotPkg.Name != wantPkg.Name { + t.Fatalf("package %q name = %q, want %q", wantPkg.PkgPath, gotPkg.Name, wantPkg.Name) + } + if !equalStrings(gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) { + t.Fatalf("package %q compiled files = %v, want %v", wantPkg.PkgPath, gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) + } + if !equalImportPaths(gotPkg.Imports, wantPkg.Imports) { + t.Fatalf("package %q imports = %v, want %v", wantPkg.PkgPath, 
sortedImportPaths(gotPkg.Imports), sortedImportPaths(wantPkg.Imports)) + } + gotErrs := comparableErrors(gotPkg.Errors) + wantErrs := comparableErrors(wantPkg.Errors) + if len(gotErrs) != len(wantErrs) { + t.Fatalf("package %q comparable errors len = %d, want %d; got=%v want=%v", wantPkg.PkgPath, len(gotErrs), len(wantErrs), gotErrs, wantErrs) + } + for i := range gotErrs { + if gotErrs[i] != wantErrs[i] { + t.Fatalf("package %q comparable error[%d] = %q, want %q", wantPkg.PkgPath, i, gotErrs[i], wantErrs[i]) + } + } + if requireTyped { + gotTyped := gotPkg.Types != nil && gotPkg.TypesInfo != nil && len(gotPkg.Syntax) > 0 + wantTyped := wantPkg.Types != nil && wantPkg.TypesInfo != nil && len(wantPkg.Syntax) > 0 + if gotTyped != wantTyped { + t.Fatalf("package %q typed state = %v, want %v", wantPkg.PkgPath, gotTyped, wantTyped) + } + } + } +} + +func collectGraph(roots []*packages.Package) map[string]*packages.Package { + out := make(map[string]*packages.Package) + stack := append([]*packages.Package(nil), roots...) + for len(stack) > 0 { + pkg := stack[len(stack)-1] + stack = stack[:len(stack)-1] + if pkg == nil || out[pkg.PkgPath] != nil { + continue + } + out[pkg.PkgPath] = pkg + for _, imp := range pkg.Imports { + stack = append(stack, imp) + } + } + return out +} + +func equalStrings(a, b []string) bool { + if len(a) != len(b) { + return false + } + aCopy := append([]string(nil), a...) + bCopy := append([]string(nil), b...) 
+ for i := range aCopy { + aCopy[i] = normalizePathForCompare(aCopy[i]) + } + for i := range bCopy { + bCopy[i] = normalizePathForCompare(bCopy[i]) + } + sort.Strings(aCopy) + sort.Strings(bCopy) + for i := range aCopy { + if aCopy[i] != bCopy[i] { + return false + } + } + return true +} + +func equalImportPaths(a, b map[string]*packages.Package) bool { + return equalStrings(sortedImportPaths(a), sortedImportPaths(b)) +} + +func sortedImportPaths(m map[string]*packages.Package) []string { + out := make([]string, 0, len(m)) + for path := range m { + out = append(out, path) + } + sort.Strings(out) + return out +} + +func normalizePathForCompare(path string) string { + if path == "" { + return "" + } + if resolved, err := filepath.EvalSymlinks(path); err == nil && resolved != "" { + return filepath.Clean(resolved) + } + return filepath.Clean(path) +} + +func comparableErrors(errs []packages.Error) []string { + seen := make(map[string]struct{}, len(errs)) + out := make([]string, 0, len(errs)) + add := func(value string) { + if _, ok := seen[value]; ok { + return + } + seen[value] = struct{}{} + out = append(out, value) + } + for _, err := range errs { + if strings.HasPrefix(err.Msg, "# ") { + for _, value := range expandSummaryDiagnostics(err.Msg) { + add(value) + } + continue + } + pos := normalizeErrorPos(err.Pos) + add(pos + "|" + err.Msg) + } + sort.Strings(out) + return out +} + +func normalizeErrorPos(pos string) string { + if pos == "" || pos == "-" { + return pos + } + parts := strings.Split(pos, ":") + if len(parts) < 2 { + return shortenComparablePath(normalizePathForCompare(pos)) + } + path := shortenComparablePath(normalizePathForCompare(parts[0])) + return strings.Join(append([]string{path}, parts[1:]...), ":") +} + +func expandSummaryDiagnostics(msg string) []string { + lines := strings.Split(msg, "\n") + out := make([]string, 0, len(lines)) + for _, line := range lines[1:] { + line = strings.TrimSpace(line) + if line == "" { + continue + } + if parts := 
strings.SplitN(line, ": ", 2); len(parts) == 2 { + pos := normalizeErrorPos(parts[0]) + out = append(out, pos+"|"+parts[1]) + continue + } + out = append(out, line) + } + return out +} + +func shortenComparablePath(path string) string { + path = filepath.Clean(path) + parts := strings.Split(path, string(filepath.Separator)) + if len(parts) >= 2 { + return filepath.Join(parts[len(parts)-2], parts[len(parts)-1]) + } + return path +} + +func writeTestFile(t *testing.T, path string, content string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatalf("MkdirAll(%q) error = %v", path, err) + } + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("WriteFile(%q) error = %v", path, err) + } +} diff --git a/internal/loader/mode.go b/internal/loader/mode.go new file mode 100644 index 0000000..b08710b --- /dev/null +++ b/internal/loader/mode.go @@ -0,0 +1,38 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package loader + +import "strings" + +const ModeEnvVar = "WIRE_LOADER_MODE" + +func ModeFromEnv(env []string) Mode { + mode := ModeAuto + for _, entry := range env { + name, value, ok := strings.Cut(entry, "=") + if !ok || name != ModeEnvVar { + continue + } + switch strings.ToLower(strings.TrimSpace(value)) { + case string(ModeCustom): + mode = ModeCustom + case string(ModeFallback): + mode = ModeFallback + case "", string(ModeAuto): + mode = ModeAuto + } + } + return mode +} diff --git a/internal/loader/timing.go b/internal/loader/timing.go new file mode 100644 index 0000000..0211f17 --- /dev/null +++ b/internal/loader/timing.go @@ -0,0 +1,41 @@ +package loader + +import ( + "context" + "time" +) + +type timingLogger func(string, time.Duration) + +type timingKey struct{} + +func WithTiming(ctx context.Context, logf func(string, time.Duration)) context.Context { + if logf == nil { + return ctx + } + return context.WithValue(ctx, timingKey{}, timingLogger(logf)) +} + +func timing(ctx context.Context) timingLogger { + if ctx == nil { + return nil + } + if v := ctx.Value(timingKey{}); v != nil { + if t, ok := v.(timingLogger); ok { + return t + } + } + return nil +} + +func logTiming(ctx context.Context, label string, start time.Time) { + if t := timing(ctx); t != nil { + t(label, time.Since(start)) + } +} + +func logDuration(ctx context.Context, label string, d time.Duration) { + if t := timing(ctx); t != nil { + t(label, d) + } +} diff --git a/internal/wire/cache_bypass.go b/internal/wire/cache_bypass.go deleted file mode 100644 index b195eef..0000000 --- a/internal/wire/cache_bypass.go +++ /dev/null @@ -1,17 +0,0 @@ -package wire - -import "context" - -type bypassPackageCacheKey struct{} - -func withBypassPackageCache(ctx context.Context) context.Context { - return context.WithValue(ctx, bypassPackageCacheKey{}, true) -} - -func bypassPackageCache(ctx context.Context) bool { - if ctx == nil { - return false - } - v, _ := 
ctx.Value(bypassPackageCacheKey{}).(bool) - return v -} diff --git a/internal/wire/cache_coverage_test.go b/internal/wire/cache_coverage_test.go deleted file mode 100644 index faf6e62..0000000 --- a/internal/wire/cache_coverage_test.go +++ /dev/null @@ -1,1099 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "bytes" - "errors" - "io/fs" - "os" - "path/filepath" - "sort" - "sync" - "testing" - - "golang.org/x/tools/go/packages" -) - -type cacheHookState struct { - osCreateTemp func(string, string) (*os.File, error) - osMkdirAll func(string, os.FileMode) error - osReadFile func(string) ([]byte, error) - osRemove func(string) error - osRemoveAll func(string) error - osRename func(string, string) error - osStat func(string) (os.FileInfo, error) - osTempDir func() string - jsonMarshal func(any) ([]byte, error) - jsonUnmarshal func([]byte, any) error - extraCachePathsFunc func(string) []string - cacheKeyForPackage func(*packages.Package, *GenerateOptions) (string, error) - detectOutputDir func([]string) (string, error) - buildCacheFiles func([]string) ([]cacheFile, error) - buildCacheFilesFrom func([]cacheFile) ([]cacheFile, error) - rootPackageFiles func(*packages.Package) []string - hashFiles func([]string) (string, error) -} - -var cacheHooksMu sync.Mutex - -func lockCacheHooks(t *testing.T) { - t.Helper() - cacheHooksMu.Lock() - t.Cleanup(func() { - cacheHooksMu.Unlock() - }) -} - -func 
saveCacheHooks() cacheHookState { - return cacheHookState{ - osCreateTemp: osCreateTemp, - osMkdirAll: osMkdirAll, - osReadFile: osReadFile, - osRemove: osRemove, - osRemoveAll: osRemoveAll, - osRename: osRename, - osStat: osStat, - osTempDir: osTempDir, - jsonMarshal: jsonMarshal, - jsonUnmarshal: jsonUnmarshal, - extraCachePathsFunc: extraCachePathsFunc, - cacheKeyForPackage: cacheKeyForPackageFunc, - detectOutputDir: detectOutputDirFunc, - buildCacheFiles: buildCacheFilesFunc, - buildCacheFilesFrom: buildCacheFilesFromMetaFunc, - rootPackageFiles: rootPackageFilesFunc, - hashFiles: hashFilesFunc, - } -} - -func restoreCacheHooks(state cacheHookState) { - osCreateTemp = state.osCreateTemp - osMkdirAll = state.osMkdirAll - osReadFile = state.osReadFile - osRemove = state.osRemove - osRemoveAll = state.osRemoveAll - osRename = state.osRename - osStat = state.osStat - osTempDir = state.osTempDir - jsonMarshal = state.jsonMarshal - jsonUnmarshal = state.jsonUnmarshal - extraCachePathsFunc = state.extraCachePathsFunc - cacheKeyForPackageFunc = state.cacheKeyForPackage - detectOutputDirFunc = state.detectOutputDir - buildCacheFilesFunc = state.buildCacheFiles - buildCacheFilesFromMetaFunc = state.buildCacheFilesFrom - rootPackageFilesFunc = state.rootPackageFiles - hashFilesFunc = state.hashFiles -} - -func writeTempFile(t *testing.T, dir, name, content string) string { - t.Helper() - path := filepath.Join(dir, name) - if err := os.WriteFile(path, []byte(content), 0644); err != nil { - t.Fatalf("WriteFile(%s) failed: %v", path, err) - } - return path -} - -func cloneManifest(src *cacheManifest) *cacheManifest { - if src == nil { - return nil - } - dst := *src - if src.Patterns != nil { - dst.Patterns = append([]string(nil), src.Patterns...) - } - if src.ExtraFiles != nil { - dst.ExtraFiles = append([]cacheFile(nil), src.ExtraFiles...) 
- } - if src.Packages != nil { - dst.Packages = make([]manifestPackage, len(src.Packages)) - for i, pkg := range src.Packages { - dstPkg := pkg - if pkg.Files != nil { - dstPkg.Files = append([]cacheFile(nil), pkg.Files...) - } - if pkg.RootFiles != nil { - dstPkg.RootFiles = append([]cacheFile(nil), pkg.RootFiles...) - } - dst.Packages[i] = dstPkg - } - } - return &dst -} - -func TestCacheStoreReadWrite(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - if got := CacheDir(); got == "" { - t.Fatal("expected CacheDir to return a value") - } - - key := "cache-store" - want := []byte("content") - writeCache(key, want) - - got, ok := readCache(key) - if !ok { - t.Fatal("expected cache hit") - } - if !bytes.Equal(got, want) { - t.Fatalf("cache content mismatch: got %q, want %q", got, want) - } - if err := ClearCache(); err != nil { - t.Fatalf("ClearCache failed: %v", err) - } - if _, ok := readCache(key); ok { - t.Fatal("expected cache miss after clear") - } -} - -func TestClearCacheClearsIncrementalSessions(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - sessionA := getIncrementalSession("/tmp/example", []string{"WIRE_INCREMENTAL=1"}, "") - if sessionA == nil { - t.Fatal("expected incremental session") - } - sessionB := getIncrementalSession("/tmp/example", []string{"WIRE_INCREMENTAL=1"}, "") - if sessionA != sessionB { - t.Fatal("expected same incremental session before clear") - } - - if err := ClearCache(); err != nil { - t.Fatalf("ClearCache failed: %v", err) - } - - sessionC := getIncrementalSession("/tmp/example", []string{"WIRE_INCREMENTAL=1"}, "") - if sessionC == nil { - t.Fatal("expected incremental session after clear") - } - if sessionC == sessionA { - t.Fatal("expected 
ClearCache to drop in-process incremental sessions") - } -} - -func TestCacheStoreReadError(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - osReadFile = func(string) ([]byte, error) { - return nil, errors.New("boom") - } - if _, ok := readCache("missing"); ok { - t.Fatal("expected cache miss on read error") - } -} - -func TestCacheStoreWriteErrors(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - t.Run("mkdir", func(t *testing.T) { - osMkdirAll = func(string, os.FileMode) error { return errors.New("mkdir") } - writeCache("mkdir", []byte("data")) - }) - - t.Run("create", func(t *testing.T) { - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(string, string) (*os.File, error) { - return nil, errors.New("create") - } - writeCache("create", []byte("data")) - }) - - t.Run("write", func(t *testing.T) { - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(dir, pattern string) (*os.File, error) { - tmp, err := os.CreateTemp(dir, pattern) - if err != nil { - return nil, err - } - name := tmp.Name() - if err := tmp.Close(); err != nil { - return nil, err - } - return os.Open(name) - } - writeCache("write", []byte("data")) - }) - - t.Run("rename-exist", func(t *testing.T) { - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osRename = func(string, string) error { - return fs.ErrExist - } - writeCache("exist", []byte("data")) - }) - - t.Run("rename", func(t *testing.T) { - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osRename = func(string, string) error { - return errors.New("rename") - } - writeCache("rename", []byte("data")) - }) -} - -func TestCacheDirError(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - 
t.Cleanup(func() { restoreCacheHooks(state) }) - - osRemoveAll = func(string) error { return errors.New("remove") } - if err := ClearCache(); err == nil { - t.Fatal("expected ClearCache error") - } -} - -func TestPackageFiles(t *testing.T) { - tempDir := t.TempDir() - rootFile := writeTempFile(t, tempDir, "root.go", "package root\n") - childFile := writeTempFile(t, tempDir, "child.go", "package child\n") - - child := &packages.Package{ - PkgPath: "example.com/child", - CompiledGoFiles: []string{childFile}, - } - root := &packages.Package{ - PkgPath: "example.com/root", - GoFiles: []string{rootFile}, - Imports: map[string]*packages.Package{ - "child": child, - "dup": child, - "nil": nil, - }, - } - got := packageFiles(root) - sort.Strings(got) - if len(got) != 2 { - t.Fatalf("expected 2 files, got %d", len(got)) - } - if got[0] != childFile || got[1] != rootFile { - t.Fatalf("unexpected files: %v", got) - } -} - -func TestCacheKeyEmptyPackage(t *testing.T) { - key, err := cacheKeyForPackage(&packages.Package{PkgPath: "example.com/empty"}, &GenerateOptions{}) - if err != nil { - t.Fatalf("cacheKeyForPackage error: %v", err) - } - if key != "" { - t.Fatalf("expected empty cache key, got %q", key) - } -} - -func TestCacheKeyMetaHit(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - file := writeTempFile(t, tempDir, "hit.go", "package hit\n") - pkg := &packages.Package{ - PkgPath: "example.com/hit", - GoFiles: []string{file}, - } - opts := &GenerateOptions{} - files := packageFiles(pkg) - sort.Strings(files) - contentHash, err := contentHashForFiles(pkg, opts, files) - if err != nil { - t.Fatalf("contentHashForFiles error: %v", err) - } - rootFiles := rootPackageFiles(pkg) - sort.Strings(rootFiles) - rootHash, err := hashFiles(rootFiles) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - metaFiles, err := 
buildCacheFiles(files) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - meta := &cacheMeta{ - Version: cacheVersion, - PkgPath: pkg.PkgPath, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - Files: metaFiles, - ContentHash: contentHash, - RootHash: rootHash, - } - metaKey := cacheMetaKey(pkg, opts) - writeCacheMeta(metaKey, meta) - - got, err := cacheKeyForPackage(pkg, opts) - if err != nil { - t.Fatalf("cacheKeyForPackage error: %v", err) - } - if got != contentHash { - t.Fatalf("cache key mismatch: got %q, want %q", got, contentHash) - } -} - -func TestCacheKeyErrorPaths(t *testing.T) { - pkg := &packages.Package{ - PkgPath: "example.com/missing", - GoFiles: []string{filepath.Join(t.TempDir(), "missing.go")}, - } - if _, err := cacheKeyForPackage(pkg, &GenerateOptions{}); err == nil { - t.Fatal("expected cacheKeyForPackage error") - } - if _, err := buildCacheFiles([]string{filepath.Join(t.TempDir(), "missing.go")}); err == nil { - t.Fatal("expected buildCacheFiles error") - } - if _, err := contentHashForPaths("example.com/missing", &GenerateOptions{}, []string{filepath.Join(t.TempDir(), "missing.go")}); err == nil { - t.Fatal("expected contentHashForPaths error") - } - if _, err := hashFiles([]string{filepath.Join(t.TempDir(), "missing.go")}); err == nil { - t.Fatal("expected hashFiles error") - } - if got, err := hashFiles(nil); err != nil || got != "" { - t.Fatalf("expected empty hashFiles result, got %q err=%v", got, err) - } -} - -func TestCacheMetaMatches(t *testing.T) { - tempDir := t.TempDir() - file := writeTempFile(t, tempDir, "meta.go", "package meta\n") - pkg := &packages.Package{ - PkgPath: "example.com/meta", - GoFiles: []string{file}, - } - opts := &GenerateOptions{} - files := packageFiles(pkg) - sort.Strings(files) - metaFiles, err := buildCacheFiles(files) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootFiles := rootPackageFiles(pkg) - 
sort.Strings(rootFiles) - rootHash, err := hashFiles(rootFiles) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - contentHash, err := contentHashForFiles(pkg, opts, files) - if err != nil { - t.Fatalf("contentHashForFiles error: %v", err) - } - meta := &cacheMeta{ - Version: cacheVersion, - PkgPath: pkg.PkgPath, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - Files: metaFiles, - ContentHash: contentHash, - RootHash: rootHash, - } - if !cacheMetaMatches(meta, pkg, opts, files) { - t.Fatal("expected cacheMetaMatches to succeed") - } - badVersion := *meta - badVersion.Version = "nope" - if cacheMetaMatches(&badVersion, pkg, opts, files) { - t.Fatal("expected version mismatch") - } - badPkg := *meta - badPkg.PkgPath = "example.com/other" - if cacheMetaMatches(&badPkg, pkg, opts, files) { - t.Fatal("expected pkg mismatch") - } - badHeader := *meta - badHeader.HeaderHash = "bad" - if cacheMetaMatches(&badHeader, pkg, opts, files) { - t.Fatal("expected header mismatch") - } - shortFiles := *meta - shortFiles.Files = nil - if cacheMetaMatches(&shortFiles, pkg, opts, files) { - t.Fatal("expected file count mismatch") - } - fileMismatch := *meta - fileMismatch.Files = append([]cacheFile(nil), meta.Files...) 
- fileMismatch.Files[0].Size++ - if cacheMetaMatches(&fileMismatch, pkg, opts, files) { - t.Fatal("expected file metadata mismatch") - } - pkgNoRoot := &packages.Package{PkgPath: pkg.PkgPath} - if cacheMetaMatches(meta, pkgNoRoot, opts, files) { - t.Fatal("expected missing root files") - } - noRootHash := *meta - noRootHash.RootHash = "" - if cacheMetaMatches(&noRootHash, pkg, opts, files) { - t.Fatal("expected empty root hash mismatch") - } - missingRootPkg := &packages.Package{ - PkgPath: "example.com/meta", - GoFiles: []string{filepath.Join(tempDir, "missing.go")}, - } - if cacheMetaMatches(meta, missingRootPkg, opts, files) { - t.Fatal("expected root hash error") - } - badRoot := *meta - badRoot.RootHash = "bad" - if cacheMetaMatches(&badRoot, pkg, opts, files) { - t.Fatal("expected root hash mismatch") - } - emptyContent := *meta - emptyContent.ContentHash = "" - if cacheMetaMatches(&emptyContent, pkg, opts, files) { - t.Fatal("expected empty content hash mismatch") - } - - if cacheMetaMatches(meta, pkg, opts, []string{filepath.Join(tempDir, "missing.go")}) { - t.Fatal("expected buildCacheFiles error") - } -} - -func TestCacheMetaReadWriteErrors(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - if _, ok := readCacheMeta("missing"); ok { - t.Fatal("expected cache meta miss") - } - - osReadFile = func(string) ([]byte, error) { - return []byte("{bad json"), nil - } - if _, ok := readCacheMeta("bad-json"); ok { - t.Fatal("expected cache meta miss on invalid json") - } - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osMkdirAll = func(string, os.FileMode) error { return errors.New("mkdir") } - writeCacheMeta("mkdir", &cacheMeta{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - jsonMarshal = func(any) ([]byte, error) { return nil, errors.New("marshal") } - 
writeCacheMeta("marshal", &cacheMeta{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(string, string) (*os.File, error) { return nil, errors.New("create") } - writeCacheMeta("create", &cacheMeta{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(dir, pattern string) (*os.File, error) { - tmp, err := os.CreateTemp(dir, pattern) - if err != nil { - return nil, err - } - name := tmp.Name() - if err := tmp.Close(); err != nil { - return nil, err - } - return os.Open(name) - } - writeCacheMeta("write", &cacheMeta{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osRename = func(string, string) error { return errors.New("rename") } - writeCacheMeta("rename", &cacheMeta{}) -} - -func TestManifestReadWriteErrors(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - if _, ok := readManifest("missing"); ok { - t.Fatal("expected manifest miss") - } - - osReadFile = func(string) ([]byte, error) { - return []byte("{bad json"), nil - } - if _, ok := readManifest("bad-json"); ok { - t.Fatal("expected manifest miss on invalid json") - } - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osMkdirAll = func(string, os.FileMode) error { return errors.New("mkdir") } - writeManifestFile("mkdir", &cacheManifest{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - jsonMarshal = func(any) ([]byte, error) { return nil, errors.New("marshal") } - writeManifestFile("marshal", &cacheManifest{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osCreateTemp = func(string, string) (*os.File, error) { return nil, errors.New("create") } - writeManifestFile("create", &cacheManifest{}) - - restoreCacheHooks(state) - osTempDir = func() string { return 
tempDir } - osCreateTemp = func(dir, pattern string) (*os.File, error) { - tmp, err := os.CreateTemp(dir, pattern) - if err != nil { - return nil, err - } - name := tmp.Name() - if err := tmp.Close(); err != nil { - return nil, err - } - return os.Open(name) - } - writeManifestFile("write", &cacheManifest{}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - osRename = func(string, string) error { return errors.New("rename") } - writeManifestFile("rename", &cacheManifest{}) -} - -func TestManifestKeyHelpers(t *testing.T) { - if got := manifestKeyFromManifest(nil); got != "" { - t.Fatalf("expected empty manifest key, got %q", got) - } - env := []string{"A=B"} - opts := &GenerateOptions{ - Tags: "tags", - PrefixOutputFile: "prefix", - Header: []byte("header"), - } - wd := t.TempDir() - patterns := []string{"./a", "./b"} - manifest := &cacheManifest{ - WD: runCacheScope(wd, patterns), - EnvHash: envHash(env), - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - Patterns: normalizePatternsForScope(wd, packageCacheScope(wd), patterns), - } - got := manifestKeyFromManifest(manifest) - want := manifestKey(wd, env, patterns, opts) - if got != want { - t.Fatalf("manifest key mismatch: got %q, want %q", got, want) - } -} - -func TestReadManifestResultsPaths(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - wd := t.TempDir() - env := []string{"A=B"} - patterns := []string{"./..."} - opts := &GenerateOptions{} - - if _, ok := readManifestResults(wd, env, patterns, opts); ok { - t.Fatal("expected no manifest") - } - - key := manifestKey(wd, env, patterns, opts) - invalid := &cacheManifest{Version: cacheVersion, WD: wd, EnvHash: "", Packages: nil} - writeManifestFile(key, invalid) - if _, ok := readManifestResults(wd, env, patterns, opts); ok { - t.Fatal("expected invalid 
manifest miss") - } - - file := writeTempFile(t, wd, "wire.go", "package app\n") - pkg := &packages.Package{ - PkgPath: "example.com/app", - GoFiles: []string{file}, - } - files := packageFiles(pkg) - sort.Strings(files) - contentHash, err := contentHashForFiles(pkg, opts, files) - if err != nil { - t.Fatalf("contentHashForFiles error: %v", err) - } - metaFiles, err := buildCacheFiles(files) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootFiles := rootPackageFiles(pkg) - sort.Strings(rootFiles) - rootMeta, err := buildCacheFiles(rootFiles) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootHash, err := hashFiles(rootFiles) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - valid := &cacheManifest{ - Version: cacheVersion, - WD: wd, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - EnvHash: envHash(env), - Patterns: sortedStrings(patterns), - Packages: []manifestPackage{ - { - PkgPath: pkg.PkgPath, - OutputPath: filepath.Join(wd, "wire_gen.go"), - Files: metaFiles, - ContentHash: contentHash, - RootFiles: rootMeta, - RootHash: rootHash, - }, - }, - } - writeManifestFile(key, valid) - if _, ok := readManifestResults(wd, env, patterns, opts); ok { - t.Fatal("expected cache miss without content") - } - writeCache(contentHash, []byte("wire")) - if results, ok := readManifestResults(wd, env, patterns, opts); !ok || len(results) != 1 { - t.Fatalf("expected manifest cache hit, got ok=%v results=%d", ok, len(results)) - } -} - -func TestWriteManifestBranches(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - wd := t.TempDir() - env := []string{"A=B"} - patterns := []string{"./..."} - opts := &GenerateOptions{} - - writeManifest(wd, env, patterns, opts, nil) - - writeManifest(wd, env, patterns, opts, []*packages.Package{nil}) - - 
writeManifest(wd, env, patterns, opts, []*packages.Package{{PkgPath: "example.com/empty"}}) - - missingFilePkg := &packages.Package{ - PkgPath: "example.com/missing", - GoFiles: []string{filepath.Join(wd, "missing.go")}, - } - writeManifest(wd, env, patterns, opts, []*packages.Package{missingFilePkg}) - - conflictDir := t.TempDir() - fileA := writeTempFile(t, conflictDir, "a.go", "package a\n") - fileB := writeTempFile(t, t.TempDir(), "b.go", "package b\n") - conflictPkg := &packages.Package{ - PkgPath: "example.com/conflict", - GoFiles: []string{fileA, fileB}, - } - writeManifest(wd, env, patterns, opts, []*packages.Package{conflictPkg}) - - okFile := writeTempFile(t, wd, "ok.go", "package ok\n") - okPkg := &packages.Package{ - PkgPath: "example.com/ok", - GoFiles: []string{okFile}, - } - cacheKeyForPackageFunc = func(*packages.Package, *GenerateOptions) (string, error) { - return "", errors.New("cache key") - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - cacheKeyForPackageFunc = func(*packages.Package, *GenerateOptions) (string, error) { - return "", nil - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - cacheKeyForPackageFunc = func(*packages.Package, *GenerateOptions) (string, error) { - return "hash", nil - } - detectOutputDirFunc = func([]string) (string, error) { - return "", errors.New("output") - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - detectOutputDirFunc = state.detectOutputDir - buildCacheFilesFunc = func([]string) ([]cacheFile, error) { - return nil, errors.New("build") - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - call := 0 - buildCacheFilesFunc = func([]string) ([]cacheFile, error) { - call++ - if call > 1 { - return nil, errors.New("root") - } - return []cacheFile{{Path: okFile}}, nil - } - rootPackageFilesFunc = func(*packages.Package) []string { - return []string{okFile} - } - writeManifest(wd, env, patterns, opts, 
[]*packages.Package{okPkg}) - - buildCacheFilesFunc = state.buildCacheFiles - hashFilesFunc = func([]string) (string, error) { - return "", errors.New("hash") - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - restoreCacheHooks(state) - statCalls := 0 - osStat = func(name string) (os.FileInfo, error) { - statCalls++ - if statCalls > 3 { - return nil, errors.New("stat") - } - return state.osStat(name) - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) - - restoreCacheHooks(state) - osTempDir = func() string { return tempDir } - readCalls := 0 - osReadFile = func(name string) ([]byte, error) { - readCalls++ - if readCalls > 2 { - return nil, errors.New("read") - } - return state.osReadFile(name) - } - writeManifest(wd, env, patterns, opts, []*packages.Package{okPkg}) -} - -func TestManifestValidationAndExtras(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - if manifestValid(nil) { - t.Fatal("expected nil manifest invalid") - } - if manifestValid(&cacheManifest{Version: "bad"}) { - t.Fatal("expected version mismatch") - } - if manifestValid(&cacheManifest{Version: cacheVersion}) { - t.Fatal("expected missing env hash") - } - - tempDir := t.TempDir() - file := writeTempFile(t, tempDir, "valid.go", "package valid\n") - files, err := buildCacheFiles([]string{file}) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootHash, err := hashFiles([]string{file}) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - valid := &cacheManifest{ - Version: cacheVersion, - WD: tempDir, - EnvHash: "env", - Packages: []manifestPackage{{PkgPath: "example.com/valid", Files: files, RootFiles: files, ContentHash: "hash", RootHash: rootHash}}, - ExtraFiles: nil, - } - if !manifestValid(valid) { - t.Fatal("expected valid manifest") - } - - invalidExtra := cloneManifest(valid) - invalidExtra.ExtraFiles = []cacheFile{{Path: filepath.Join(tempDir, 
"missing.go")}} - if manifestValid(invalidExtra) { - t.Fatal("expected invalid extra files") - } - - extraMismatch := cloneManifest(valid) - extraMismatch.ExtraFiles = []cacheFile{files[0]} - extraMismatch.ExtraFiles[0].Size++ - if manifestValid(extraMismatch) { - t.Fatal("expected extra file metadata mismatch") - } - - invalidPkg := cloneManifest(valid) - invalidPkg.Packages[0].ContentHash = "" - if manifestValid(invalidPkg) { - t.Fatal("expected invalid content hash") - } - - invalidRoot := cloneManifest(valid) - invalidRoot.Packages[0].RootHash = "" - if manifestValid(invalidRoot) { - t.Fatal("expected invalid root hash") - } - - invalidFiles := cloneManifest(valid) - invalidFiles.Packages[0].Files = []cacheFile{{Path: filepath.Join(tempDir, "missing.go")}} - if manifestValid(invalidFiles) { - t.Fatal("expected invalid package files") - } - - fileMismatch := cloneManifest(valid) - fileMismatch.Packages[0].Files = []cacheFile{files[0]} - fileMismatch.Packages[0].Files[0].Size++ - if manifestValid(fileMismatch) { - t.Fatal("expected package file mismatch") - } - - invalidRootFiles := cloneManifest(valid) - invalidRootFiles.Packages[0].RootFiles = []cacheFile{{Path: filepath.Join(tempDir, "missing.go")}} - if manifestValid(invalidRootFiles) { - t.Fatal("expected invalid root files") - } - - rootMismatch := cloneManifest(valid) - rootMismatch.Packages[0].RootFiles = []cacheFile{files[0]} - rootMismatch.Packages[0].RootFiles[0].Size++ - if manifestValid(rootMismatch) { - t.Fatal("expected root file mismatch") - } - - emptyRoot := cloneManifest(valid) - emptyRoot.Packages[0].RootFiles = nil - if manifestValid(emptyRoot) { - t.Fatal("expected empty root files") - } - - badHash := cloneManifest(valid) - badHash.Packages[0].RootHash = "bad" - if manifestValid(badHash) { - t.Fatal("expected root hash mismatch") - } - - if _, err := buildCacheFilesFromMeta([]cacheFile{{Path: filepath.Join(tempDir, "missing.go")}}); err == nil { - t.Fatal("expected buildCacheFilesFromMeta 
error") - } - - extraCachePathsFunc = func(string) []string { - return []string{file, file, filepath.Join(tempDir, "missing.go")} - } - extras := extraCacheFiles(tempDir) - if len(extras) != 1 { - t.Fatalf("expected 1 extra file, got %d", len(extras)) - } - - extraCachePathsFunc = func(string) []string { return nil } - if extras := extraCacheFiles(tempDir); extras != nil { - t.Fatal("expected nil extras") - } - - extraCachePathsFunc = func(string) []string { return []string{file, writeTempFile(t, tempDir, "go.sum", "sum\n")} } - if extras := extraCacheFiles(tempDir); len(extras) < 2 { - t.Fatalf("expected extras to include two files, got %v", extras) - } -} - -func TestExtraCachePaths(t *testing.T) { - tempDir := t.TempDir() - rootMod := writeTempFile(t, tempDir, "go.mod", "module example.com/root\n") - writeTempFile(t, tempDir, "go.sum", "sum\n") - nested := filepath.Join(tempDir, "nested", "dir") - if err := os.MkdirAll(nested, 0755); err != nil { - t.Fatalf("MkdirAll failed: %v", err) - } - paths := extraCachePaths(nested) - if len(paths) < 2 { - t.Fatalf("expected extra cache paths, got %v", paths) - } - found := false - for _, path := range paths { - if path == rootMod { - found = true - break - } - } - if !found { - t.Fatalf("expected %s in paths: %v", rootMod, paths) - } - if got := sortedStrings(nil); got != nil { - t.Fatal("expected nil for empty sortedStrings") - } - if got := envHash(nil); got != "" { - t.Fatal("expected empty env hash") - } -} - -func TestRootPackageFiles(t *testing.T) { - if rootPackageFiles(nil) != nil { - t.Fatal("expected nil root files for nil package") - } - tempDir := t.TempDir() - compiled := writeTempFile(t, tempDir, "compiled.go", "package compiled\n") - pkg := &packages.Package{ - PkgPath: "example.com/compiled", - CompiledGoFiles: []string{compiled}, - } - got := rootPackageFiles(pkg) - if len(got) != 1 || got[0] != compiled { - t.Fatalf("unexpected compiled files: %v", got) - } -} - -func TestAddExtraCachePath(t *testing.T) 
{ - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - file := writeTempFile(t, tempDir, "go.mod", "module example.com\n") - var paths []string - seen := make(map[string]struct{}) - addExtraCachePath(&paths, seen, file) - addExtraCachePath(&paths, seen, file) - if len(paths) != 1 { - t.Fatalf("expected 1 path, got %d", len(paths)) - } - addExtraCachePath(&paths, seen, filepath.Join(tempDir, "missing.go")) - if len(paths) != 1 { - t.Fatalf("unexpected extra path append: %v", paths) - } -} - -func TestManifestValidHookBranches(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - file := writeTempFile(t, tempDir, "hook.go", "package hook\n") - files, err := buildCacheFiles([]string{file}) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootHash, err := hashFiles([]string{file}) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - base := &cacheManifest{ - Version: cacheVersion, - WD: tempDir, - EnvHash: "env", - Packages: []manifestPackage{{PkgPath: "example.com/hook", Files: files, RootFiles: files, ContentHash: "hash", RootHash: rootHash}}, - ExtraFiles: []cacheFile{files[0]}, - } - - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - if len(in) == 1 && in[0].Path == files[0].Path { - return []cacheFile{}, nil - } - return buildCacheFilesFromMeta(in) - } - if manifestValid(base) { - t.Fatal("expected extra file length mismatch") - } - - restoreCacheHooks(state) - emptyRoot := cloneManifest(base) - emptyRoot.Packages[0].RootFiles = nil - if manifestValid(emptyRoot) { - t.Fatal("expected empty root files") - } - - restoreCacheHooks(state) - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - if len(in) == 1 && in[0].Path == file { - return nil, errors.New("pkg files") - } - return buildCacheFilesFromMeta(in) - } - noExtra := 
cloneManifest(base) - noExtra.ExtraFiles = nil - if manifestValid(noExtra) { - t.Fatal("expected pkg files error") - } - - restoreCacheHooks(state) - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - if len(in) == 1 && in[0].Path == file { - return []cacheFile{}, nil - } - return buildCacheFilesFromMeta(in) - } - if manifestValid(noExtra) { - t.Fatal("expected pkg files length mismatch") - } - - restoreCacheHooks(state) - call := 0 - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - call++ - if call == 2 { - return nil, errors.New("root files") - } - return buildCacheFilesFromMeta(in) - } - if manifestValid(noExtra) { - t.Fatal("expected root files error") - } - - restoreCacheHooks(state) - call = 0 - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - call++ - if call == 2 { - return []cacheFile{}, nil - } - return buildCacheFilesFromMeta(in) - } - if manifestValid(noExtra) { - t.Fatal("expected root files length mismatch") - } - - restoreCacheHooks(state) - call = 0 - buildCacheFilesFromMetaFunc = func(in []cacheFile) ([]cacheFile, error) { - call++ - if call == 2 { - return []cacheFile{{Path: file, Size: files[0].Size + 1}}, nil - } - return buildCacheFilesFromMeta(in) - } - if manifestValid(noExtra) { - t.Fatal("expected root files mismatch") - } -} diff --git a/internal/wire/cache_generate_test.go b/internal/wire/cache_generate_test.go deleted file mode 100644 index d009f73..0000000 --- a/internal/wire/cache_generate_test.go +++ /dev/null @@ -1,100 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "os" - "path/filepath" - "sort" - "testing" - - "golang.org/x/tools/go/packages" -) - -func TestGenerateUsesManifestCache(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - wd := t.TempDir() - file := filepath.Join(wd, "provider.go") - if err := os.WriteFile(file, []byte("package p\n"), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - - env := []string{"A=B"} - patterns := []string{"./..."} - opts := &GenerateOptions{} - key := manifestKey(wd, env, patterns, opts) - - pkg := &packages.Package{ - PkgPath: "example.com/p", - GoFiles: []string{file}, - } - files := packageFiles(pkg) - sort.Strings(files) - contentHash, err := contentHashForFiles(pkg, opts, files) - if err != nil { - t.Fatalf("contentHashForFiles error: %v", err) - } - metaFiles, err := buildCacheFiles(files) - if err != nil { - t.Fatalf("buildCacheFiles error: %v", err) - } - rootFiles := rootPackageFiles(pkg) - sort.Strings(rootFiles) - rootMeta, err := buildCacheFiles(rootFiles) - if err != nil { - t.Fatalf("buildCacheFiles root error: %v", err) - } - rootHash, err := hashFiles(rootFiles) - if err != nil { - t.Fatalf("hashFiles error: %v", err) - } - - manifest := &cacheManifest{ - Version: cacheVersion, - WD: wd, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - EnvHash: envHash(env), - Patterns: sortedStrings(patterns), - Packages: 
[]manifestPackage{ - { - PkgPath: pkg.PkgPath, - OutputPath: filepath.Join(wd, "wire_gen.go"), - Files: metaFiles, - ContentHash: contentHash, - RootFiles: rootMeta, - RootHash: rootHash, - }, - }, - } - writeManifestFile(key, manifest) - writeCache(contentHash, []byte("wire")) - - results, errs := Generate(context.Background(), wd, env, patterns, opts) - if len(errs) > 0 { - t.Fatalf("Generate returned errors: %v", errs) - } - if len(results) != 1 || string(results[0].Content) != "wire" { - t.Fatalf("unexpected cached results: %+v", results) - } -} diff --git a/internal/wire/cache_key.go b/internal/wire/cache_key.go deleted file mode 100644 index f22c6c0..0000000 --- a/internal/wire/cache_key.go +++ /dev/null @@ -1,352 +0,0 @@ -// Copyright 2018 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "crypto/sha256" - "fmt" - "path/filepath" - "runtime" - "sort" - "sync" - - "golang.org/x/tools/go/packages" -) - -// cacheVersion is the schema/version identifier for cache entries. -const cacheVersion = "wire-cache-v3" - -// cacheFile captures file metadata used to validate cached content. -type cacheFile struct { - Path string `json:"path"` - Size int64 `json:"size"` - ModTime int64 `json:"mod_time"` -} - -// cacheMeta tracks inputs and outputs for a single package cache entry. 
-type cacheMeta struct { - Version string `json:"version"` - PkgPath string `json:"pkg_path"` - Tags string `json:"tags"` - Prefix string `json:"prefix"` - HeaderHash string `json:"header_hash"` - Files []cacheFile `json:"files"` - ContentHash string `json:"content_hash"` - RootHash string `json:"root_hash"` -} - -// cacheKeyForPackage returns the content hash for a package, if cacheable. -func cacheKeyForPackage(pkg *packages.Package, opts *GenerateOptions) (string, error) { - files := packageFiles(pkg) - if len(files) == 0 { - return "", nil - } - sort.Strings(files) - metaKey := cacheMetaKey(pkg, opts) - if meta, ok := readCacheMeta(metaKey); ok { - if cacheMetaMatches(meta, pkg, opts, files) { - return meta.ContentHash, nil - } - } - contentHash, err := contentHashForFiles(pkg, opts, files) - if err != nil { - return "", err - } - rootFiles := rootPackageFiles(pkg) - sort.Strings(rootFiles) - rootHash, err := hashFiles(rootFiles) - if err != nil { - return "", err - } - metaFiles, err := buildCacheFiles(files) - if err != nil { - return "", err - } - meta := &cacheMeta{ - Version: cacheVersion, - PkgPath: pkg.PkgPath, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - Files: metaFiles, - ContentHash: contentHash, - RootHash: rootHash, - } - writeCacheMeta(metaKey, meta) - return contentHash, nil -} - -// packageFiles returns the transitive Go files for a package graph. -func packageFiles(root *packages.Package) []string { - seen := make(map[string]struct{}) - var files []string - stack := []*packages.Package{root} - for len(stack) > 0 { - p := stack[len(stack)-1] - stack = stack[:len(stack)-1] - if p == nil { - continue - } - if _, ok := seen[p.PkgPath]; ok { - continue - } - seen[p.PkgPath] = struct{}{} - if len(p.CompiledGoFiles) > 0 { - files = append(files, p.CompiledGoFiles...) - } else if len(p.GoFiles) > 0 { - files = append(files, p.GoFiles...) 
- } - for _, imp := range p.Imports { - stack = append(stack, imp) - } - } - return files -} - -// cacheMetaKey builds the key for a package's cache metadata entry. -func cacheMetaKey(pkg *packages.Package, opts *GenerateOptions) string { - h := sha256.New() - h.Write([]byte(cacheVersion)) - h.Write([]byte{0}) - h.Write([]byte(pkg.PkgPath)) - h.Write([]byte{0}) - h.Write([]byte(opts.Tags)) - h.Write([]byte{0}) - h.Write([]byte(opts.PrefixOutputFile)) - h.Write([]byte{0}) - h.Write([]byte(headerHash(opts.Header))) - return fmt.Sprintf("%x", h.Sum(nil)) -} - -// cacheMetaPath returns the on-disk path for a cache metadata key. -func cacheMetaPath(key string) string { - return filepath.Join(cacheDir(), key+".json") -} - -// readCacheMeta loads a cached metadata entry if it exists. -func readCacheMeta(key string) (*cacheMeta, bool) { - data, err := osReadFile(cacheMetaPath(key)) - if err != nil { - return nil, false - } - var meta cacheMeta - if err := jsonUnmarshal(data, &meta); err != nil { - return nil, false - } - return &meta, true -} - -// writeCacheMeta persists cache metadata to disk. -func writeCacheMeta(key string, meta *cacheMeta) { - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - data, err := jsonMarshal(meta) - if err != nil { - return - } - tmp, err := osCreateTemp(dir, key+".meta-") - if err != nil { - return - } - _, writeErr := tmp.Write(data) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - path := cacheMetaPath(key) - if err := osRename(tmp.Name(), path); err != nil { - osRemove(tmp.Name()) - } -} - -// cacheMetaMatches reports whether metadata matches the current package inputs. 
-func cacheMetaMatches(meta *cacheMeta, pkg *packages.Package, opts *GenerateOptions, files []string) bool { - if meta.Version != cacheVersion { - return false - } - if meta.PkgPath != pkg.PkgPath || meta.Tags != opts.Tags || meta.Prefix != opts.PrefixOutputFile { - return false - } - if meta.HeaderHash != headerHash(opts.Header) { - return false - } - if len(meta.Files) != len(files) { - return false - } - current, err := buildCacheFiles(files) - if err != nil { - return false - } - for i := range meta.Files { - if meta.Files[i] != current[i] { - return false - } - } - rootFiles := rootPackageFiles(pkg) - if len(rootFiles) == 0 || meta.RootHash == "" { - return false - } - sort.Strings(rootFiles) - rootHash, err := hashFiles(rootFiles) - if err != nil || rootHash != meta.RootHash { - return false - } - return meta.ContentHash != "" -} - -// buildCacheFiles converts file paths into cache metadata entries. -func buildCacheFiles(files []string) ([]cacheFile, error) { - return buildCacheFilesWithStats(files, func(path string) (cacheFile, error) { - info, err := osStat(path) - if err != nil { - return cacheFile{}, err - } - return cacheFile{ - Path: filepath.Clean(path), - Size: info.Size(), - ModTime: info.ModTime().UnixNano(), - }, nil - }) -} - -func buildCacheFilesWithStats[T any](items []T, stat func(T) (cacheFile, error)) ([]cacheFile, error) { - if len(items) == 0 { - return nil, nil - } - if len(items) == 1 { - file, err := stat(items[0]) - if err != nil { - return nil, err - } - return []cacheFile{file}, nil - } - out := make([]cacheFile, len(items)) - workers := runtime.GOMAXPROCS(0) - if workers < 4 { - workers = 4 - } - if workers > len(items) { - workers = len(items) - } - var ( - wg sync.WaitGroup - mu sync.Mutex - firstErr error - indexCh = make(chan int, len(items)) - ) - for i := range items { - indexCh <- i - } - close(indexCh) - wg.Add(workers) - for i := 0; i < workers; i++ { - go func() { - defer wg.Done() - for i := range indexCh { - file, err := 
stat(items[i]) - if err != nil { - mu.Lock() - if firstErr == nil { - firstErr = err - } - mu.Unlock() - continue - } - out[i] = file - } - }() - } - wg.Wait() - if firstErr != nil { - return nil, firstErr - } - return out, nil -} - -// headerHash returns a stable hash of the generated header content. -func headerHash(header []byte) string { - if len(header) == 0 { - return "" - } - sum := sha256.Sum256(header) - return fmt.Sprintf("%x", sum[:]) -} - -// contentHashForFiles hashes the current package inputs using file paths. -func contentHashForFiles(pkg *packages.Package, opts *GenerateOptions, files []string) (string, error) { - return contentHashForPaths(pkg.PkgPath, opts, files) -} - -// contentHashForPaths hashes the provided file contents and options. -func contentHashForPaths(pkgPath string, opts *GenerateOptions, files []string) (string, error) { - h := sha256.New() - h.Write([]byte(cacheVersion)) - h.Write([]byte{0}) - h.Write([]byte(pkgPath)) - h.Write([]byte{0}) - h.Write([]byte(opts.Tags)) - h.Write([]byte{0}) - h.Write([]byte(opts.PrefixOutputFile)) - h.Write([]byte{0}) - h.Write([]byte(headerHash(opts.Header))) - h.Write([]byte{0}) - for _, name := range files { - h.Write([]byte(name)) - h.Write([]byte{0}) - data, err := osReadFile(name) - if err != nil { - return "", err - } - h.Write(data) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)), nil -} - -// rootPackageFiles returns the direct Go files for the root package. -func rootPackageFiles(pkg *packages.Package) []string { - if pkg == nil { - return nil - } - if len(pkg.CompiledGoFiles) > 0 { - return append([]string(nil), pkg.CompiledGoFiles...) - } - if len(pkg.GoFiles) > 0 { - return append([]string(nil), pkg.GoFiles...) - } - return nil -} - -// hashFiles returns a combined content hash for the provided paths. 
-func hashFiles(files []string) (string, error) { - if len(files) == 0 { - return "", nil - } - h := sha256.New() - for _, name := range files { - h.Write([]byte(name)) - h.Write([]byte{0}) - data, err := osReadFile(name) - if err != nil { - return "", err - } - h.Write(data) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)), nil -} diff --git a/internal/wire/cache_manifest.go b/internal/wire/cache_manifest.go deleted file mode 100644 index 57be68b..0000000 --- a/internal/wire/cache_manifest.go +++ /dev/null @@ -1,393 +0,0 @@ -// Copyright 2018 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "crypto/sha256" - "fmt" - "path/filepath" - "sort" - - "golang.org/x/tools/go/packages" -) - -// cacheManifest stores per-run cache metadata for generated packages. -type cacheManifest struct { - Version string `json:"version"` - WD string `json:"wd"` - Tags string `json:"tags"` - Prefix string `json:"prefix"` - HeaderHash string `json:"header_hash"` - EnvHash string `json:"env_hash"` - Patterns []string `json:"patterns"` - Packages []manifestPackage `json:"packages"` - ExtraFiles []cacheFile `json:"extra_files"` -} - -// manifestPackage captures cached output for a single package. 
-type manifestPackage struct { - PkgPath string `json:"pkg_path"` - OutputPath string `json:"output_path"` - Files []cacheFile `json:"files"` - ContentHash string `json:"content_hash"` - RootFiles []cacheFile `json:"root_files"` - RootHash string `json:"root_hash"` -} - -var extraCachePathsFunc = extraCachePaths - -// readManifestResults loads cached generation results if still valid. -func readManifestResults(wd string, env []string, patterns []string, opts *GenerateOptions) ([]GenerateResult, bool) { - key := manifestKey(wd, env, patterns, opts) - manifest, ok := readManifest(key) - if !ok { - return nil, false - } - if !manifestValid(manifest) { - return nil, false - } - results := make([]GenerateResult, 0, len(manifest.Packages)) - for _, pkg := range manifest.Packages { - content, ok := readCache(pkg.ContentHash) - if !ok { - return nil, false - } - results = append(results, GenerateResult{ - PkgPath: pkg.PkgPath, - OutputPath: pkg.OutputPath, - Content: content, - }) - } - return results, true -} - -// writeManifest persists cache metadata for a successful run. 
-func writeManifest(wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package) { - if len(pkgs) == 0 { - return - } - key := manifestKey(wd, env, patterns, opts) - scope := runCacheScope(wd, patterns) - manifest := &cacheManifest{ - Version: cacheVersion, - WD: scope, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - EnvHash: envHash(env), - Patterns: normalizePatternsForScope(wd, packageCacheScope(wd), patterns), - } - manifest.ExtraFiles = extraCacheFiles(wd) - for _, pkg := range pkgs { - if pkg == nil { - continue - } - files := packageFiles(pkg) - if len(files) == 0 { - continue - } - sort.Strings(files) - contentHash, err := cacheKeyForPackageFunc(pkg, opts) - if err != nil || contentHash == "" { - continue - } - outDir, err := detectOutputDirFunc(pkg.GoFiles) - if err != nil { - continue - } - outputPath := filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") - metaFiles, err := buildCacheFilesFunc(files) - if err != nil { - continue - } - rootFiles := rootPackageFilesFunc(pkg) - sort.Strings(rootFiles) - rootMeta, err := buildCacheFilesFunc(rootFiles) - if err != nil { - continue - } - rootHash, err := hashFilesFunc(rootFiles) - if err != nil { - continue - } - manifest.Packages = append(manifest.Packages, manifestPackage{ - PkgPath: pkg.PkgPath, - OutputPath: outputPath, - Files: metaFiles, - ContentHash: contentHash, - RootFiles: rootMeta, - RootHash: rootHash, - }) - } - writeManifestFile(key, manifest) -} - -// manifestKey builds the cache key for a given run configuration. 
-func manifestKey(wd string, env []string, patterns []string, opts *GenerateOptions) string { - h := sha256.New() - h.Write([]byte(cacheVersion)) - h.Write([]byte{0}) - h.Write([]byte(runCacheScope(wd, patterns))) - h.Write([]byte{0}) - h.Write([]byte(envHash(env))) - h.Write([]byte{0}) - h.Write([]byte(opts.Tags)) - h.Write([]byte{0}) - h.Write([]byte(opts.PrefixOutputFile)) - h.Write([]byte{0}) - h.Write([]byte(headerHash(opts.Header))) - h.Write([]byte{0}) - for _, p := range normalizePatternsForScope(wd, packageCacheScope(wd), patterns) { - h.Write([]byte(p)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} - -// manifestKeyFromManifest rebuilds the cache key from stored metadata. -func manifestKeyFromManifest(manifest *cacheManifest) string { - if manifest == nil { - return "" - } - h := sha256.New() - h.Write([]byte(cacheVersion)) - h.Write([]byte{0}) - h.Write([]byte(filepath.Clean(manifest.WD))) - h.Write([]byte{0}) - h.Write([]byte(manifest.EnvHash)) - h.Write([]byte{0}) - h.Write([]byte(manifest.Tags)) - h.Write([]byte{0}) - h.Write([]byte(manifest.Prefix)) - h.Write([]byte{0}) - h.Write([]byte(manifest.HeaderHash)) - h.Write([]byte{0}) - for _, p := range sortedStrings(manifest.Patterns) { - h.Write([]byte(p)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} - -// readManifest loads the cached manifest by key. -func readManifest(key string) (*cacheManifest, bool) { - data, err := osReadFile(cacheManifestPath(key)) - if err != nil { - return nil, false - } - var manifest cacheManifest - if err := jsonUnmarshal(data, &manifest); err != nil { - return nil, false - } - return &manifest, true -} - -// writeManifestFile writes the manifest to disk. 
-func writeManifestFile(key string, manifest *cacheManifest) { - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - data, err := jsonMarshal(manifest) - if err != nil { - return - } - tmp, err := osCreateTemp(dir, key+".manifest-") - if err != nil { - return - } - _, writeErr := tmp.Write(data) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - path := cacheManifestPath(key) - if err := osRename(tmp.Name(), path); err != nil { - osRemove(tmp.Name()) - } -} - -// cacheManifestPath returns the on-disk path for a manifest key. -func cacheManifestPath(key string) string { - return filepath.Join(cacheDir(), key+".manifest.json") -} - -// manifestValid reports whether the manifest still matches current inputs. -func manifestValid(manifest *cacheManifest) bool { - if manifest == nil || manifest.Version != cacheVersion { - return false - } - if manifest.EnvHash == "" || len(manifest.Packages) == 0 { - return false - } - if len(manifest.ExtraFiles) > 0 { - current, err := buildCacheFilesFromMetaFunc(manifest.ExtraFiles) - if err != nil { - return false - } - if len(current) != len(manifest.ExtraFiles) { - return false - } - for i := range manifest.ExtraFiles { - if manifest.ExtraFiles[i] != current[i] { - return false - } - } - } - for i := range manifest.Packages { - pkg := manifest.Packages[i] - if pkg.ContentHash == "" { - return false - } - if len(pkg.RootFiles) == 0 || pkg.RootHash == "" { - return false - } - current, err := buildCacheFilesFromMetaFunc(pkg.Files) - if err != nil { - return false - } - if len(current) != len(pkg.Files) { - return false - } - for j := range pkg.Files { - if pkg.Files[j] != current[j] { - return false - } - } - rootCurrent, err := buildCacheFilesFromMetaFunc(pkg.RootFiles) - if err != nil { - return false - } - if len(rootCurrent) != len(pkg.RootFiles) { - return false - } - for j := range pkg.RootFiles { - if pkg.RootFiles[j] != rootCurrent[j] { - 
return false - } - } - rootPaths := make([]string, 0, len(pkg.RootFiles)) - for _, file := range pkg.RootFiles { - rootPaths = append(rootPaths, file.Path) - } - sort.Strings(rootPaths) - rootHash, err := hashFiles(rootPaths) - if err != nil || rootHash != pkg.RootHash { - return false - } - } - return true -} - -// buildCacheFilesFromMeta re-stats files to compare metadata. -func buildCacheFilesFromMeta(files []cacheFile) ([]cacheFile, error) { - return buildCacheFilesWithStats(files, func(file cacheFile) (cacheFile, error) { - info, err := osStat(file.Path) - if err != nil { - return cacheFile{}, err - } - return cacheFile{ - Path: filepath.Clean(file.Path), - Size: info.Size(), - ModTime: info.ModTime().UnixNano(), - }, nil - }) -} - -// extraCacheFiles returns Go module/workspace files affecting builds. -func extraCacheFiles(wd string) []cacheFile { - paths := extraCachePathsFunc(wd) - if len(paths) == 0 { - return nil - } - out := make([]cacheFile, 0, len(paths)) - seen := make(map[string]struct{}) - for _, path := range paths { - path = filepath.Clean(path) - if _, ok := seen[path]; ok { - continue - } - info, err := osStat(path) - if err != nil { - continue - } - seen[path] = struct{}{} - out = append(out, cacheFile{ - Path: path, - Size: info.Size(), - ModTime: info.ModTime().UnixNano(), - }) - } - sort.Slice(out, func(i, j int) bool { - return out[i].Path < out[j].Path - }) - return out -} - -// extraCachePaths finds go.mod/go.sum/go.work files for a working dir. -func extraCachePaths(wd string) []string { - var paths []string - dir := filepath.Clean(wd) - seen := make(map[string]struct{}) - for { - for _, name := range []string{"go.work", "go.work.sum", "go.mod", "go.sum"} { - full := filepath.Join(dir, name) - addExtraCachePath(&paths, seen, full) - } - parent := filepath.Dir(dir) - if parent == dir { - break - } - dir = parent - } - return paths -} - -// addExtraCachePath appends an existing file if it has not been seen. 
-func addExtraCachePath(paths *[]string, seen map[string]struct{}, full string) { - if _, ok := seen[full]; ok { - return - } - if _, err := osStat(full); err != nil { - return - } - *paths = append(*paths, full) - seen[full] = struct{}{} -} - -// sortedStrings returns a sorted copy of the input slice. -func sortedStrings(values []string) []string { - if len(values) == 0 { - return nil - } - out := append([]string(nil), values...) - sort.Strings(out) - return out -} - -// envHash returns a stable hash of environment variables. -func envHash(env []string) string { - if len(env) == 0 { - return "" - } - sorted := sortedStrings(env) - h := sha256.New() - for _, v := range sorted { - h.Write([]byte(v)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} diff --git a/internal/wire/cache_scope.go b/internal/wire/cache_scope.go deleted file mode 100644 index fe161a7..0000000 --- a/internal/wire/cache_scope.go +++ /dev/null @@ -1,69 +0,0 @@ -package wire - -import ( - "path/filepath" - "sort" - "strings" -) - -func packageCacheScope(wd string) string { - if root := findModuleRoot(wd); root != "" { - return filepath.Clean(root) - } - return filepath.Clean(wd) -} - -func runCacheScope(wd string, patterns []string) string { - scopeRoot := packageCacheScope(wd) - normalized := normalizePatternsForScope(wd, scopeRoot, patterns) - if len(normalized) == 0 { - return scopeRoot - } - return scopeRoot + "\n" + strings.Join(normalized, "\n") -} - -func normalizePatternsForScope(wd string, scopeRoot string, patterns []string) []string { - if len(patterns) == 0 { - return nil - } - out := make([]string, 0, len(patterns)) - for _, pattern := range patterns { - out = append(out, normalizePatternForScope(wd, scopeRoot, pattern)) - } - sort.Strings(out) - return out -} - -func normalizePatternForScope(wd string, scopeRoot string, pattern string) string { - if pattern == "" { - return pattern - } - if filepath.IsAbs(pattern) || strings.HasPrefix(pattern, ".") { - abs := 
pattern - if !filepath.IsAbs(abs) { - abs = filepath.Join(wd, pattern) - } - abs = filepath.Clean(abs) - if scopeRoot != "" { - if rel, ok := pathWithinRoot(scopeRoot, abs); ok { - if rel == "." { - return "." - } - return filepath.ToSlash(rel) - } - } - return filepath.ToSlash(abs) - } - return pattern -} - -func pathWithinRoot(root string, path string) (string, bool) { - rel, err := filepath.Rel(filepath.Clean(root), filepath.Clean(path)) - if err != nil { - return "", false - } - if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) { - return "", false - } - return rel, true -} diff --git a/internal/wire/cache_scope_test.go b/internal/wire/cache_scope_test.go deleted file mode 100644 index 9cc518b..0000000 --- a/internal/wire/cache_scope_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package wire - -import ( - "path/filepath" - "testing" -) - -func TestRunScopedKeysIgnoreEquivalentWorkingDirectories(t *testing.T) { - root := t.TempDir() - writeFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.22\n") - wireDir := filepath.Join(root, "wire") - writeFile(t, filepath.Join(wireDir, "wire.go"), "package wire\n") - - env := []string{"GOOS=darwin"} - opts := &GenerateOptions{Tags: "wireinject", PrefixOutputFile: "gen_"} - - rootKey := manifestKey(root, env, []string{"./wire"}, opts) - subdirKey := manifestKey(wireDir, env, []string{"."}, opts) - if rootKey != subdirKey { - t.Fatalf("manifestKey mismatch: root=%q subdir=%q", rootKey, subdirKey) - } - - rootIncrementalKey := incrementalManifestSelectorKey(root, env, []string{"./wire"}, opts) - subdirIncrementalKey := incrementalManifestSelectorKey(wireDir, env, []string{"."}, opts) - if rootIncrementalKey != subdirIncrementalKey { - t.Fatalf("incrementalManifestSelectorKey mismatch: root=%q subdir=%q", rootIncrementalKey, subdirIncrementalKey) - } -} - -func TestPackageScopedKeysIgnoreEquivalentWorkingDirectories(t *testing.T) { - root := t.TempDir() - writeFile(t, filepath.Join(root, 
"go.mod"), "module example.com/app\n\ngo 1.22\n") - wireDir := filepath.Join(root, "wire") - writeFile(t, filepath.Join(wireDir, "wire.go"), "package wire\n") - - rootFingerprintKey := incrementalFingerprintKey(root, "wireinject", "example.com/app/wire") - subdirFingerprintKey := incrementalFingerprintKey(wireDir, "wireinject", "example.com/app/wire") - if rootFingerprintKey != subdirFingerprintKey { - t.Fatalf("incrementalFingerprintKey mismatch: root=%q subdir=%q", rootFingerprintKey, subdirFingerprintKey) - } - - rootSummaryKey := incrementalSummaryKey(root, "wireinject", "example.com/app/wire") - subdirSummaryKey := incrementalSummaryKey(wireDir, "wireinject", "example.com/app/wire") - if rootSummaryKey != subdirSummaryKey { - t.Fatalf("incrementalSummaryKey mismatch: root=%q subdir=%q", rootSummaryKey, subdirSummaryKey) - } - - rootGraphKey := incrementalGraphKey(root, "wireinject", []string{"example.com/app/wire"}) - subdirGraphKey := incrementalGraphKey(wireDir, "wireinject", []string{"example.com/app/wire"}) - if rootGraphKey != subdirGraphKey { - t.Fatalf("incrementalGraphKey mismatch: root=%q subdir=%q", rootGraphKey, subdirGraphKey) - } - - rootSessionKey := sessionKey(root, []string{"GOOS=darwin"}, "wireinject") - subdirSessionKey := sessionKey(wireDir, []string{"GOOS=darwin"}, "wireinject") - if rootSessionKey != subdirSessionKey { - t.Fatalf("sessionKey mismatch: root=%q subdir=%q", rootSessionKey, subdirSessionKey) - } -} diff --git a/internal/wire/cache_store.go b/internal/wire/cache_store.go deleted file mode 100644 index 0c959cf..0000000 --- a/internal/wire/cache_store.go +++ /dev/null @@ -1,77 +0,0 @@ -// Copyright 2018 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "errors" - "io/fs" - "path/filepath" -) - -// cacheDir returns the base directory for Wire cache files. -func cacheDir() string { - return filepath.Join(osTempDir(), "wire-cache") -} - -// CacheDir returns the directory used for Wire's cache. -func CacheDir() string { - return cacheDir() -} - -// ClearCache removes all cached data. -func ClearCache() error { - clearIncrementalSessions() - return osRemoveAll(cacheDir()) -} - -// cachePath builds the on-disk path for a cached content hash. -func cachePath(key string) string { - return filepath.Join(cacheDir(), key+".bin") -} - -// readCache reads a cached content blob by key. -func readCache(key string) ([]byte, bool) { - data, err := osReadFile(cachePath(key)) - if err != nil { - return nil, false - } - return data, true -} - -// writeCache persists a content blob for the provided cache key. 
-func writeCache(key string, content []byte) { - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - path := cachePath(key) - tmp, err := osCreateTemp(dir, key+".tmp-") - if err != nil { - return - } - _, writeErr := tmp.Write(content) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - if err := osRename(tmp.Name(), path); err != nil { - if errors.Is(err, fs.ErrExist) { - osRemove(tmp.Name()) - return - } - osRemove(tmp.Name()) - } -} diff --git a/internal/wire/cache_test.go b/internal/wire/cache_test.go deleted file mode 100644 index 6ffb20a..0000000 --- a/internal/wire/cache_test.go +++ /dev/null @@ -1,387 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package wire - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" -) - -func TestCacheInvalidation(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - prevTmp := os.Getenv("TMPDIR") - if err := os.Setenv("TMPDIR", t.TempDir()); err != nil { - t.Fatalf("Setenv TMPDIR failed: %v", err) - } - t.Cleanup(func() { - os.Setenv("TMPDIR", prevTmp) - }) - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() string {", - "\twire.Build(dep.ProvideMessage)", - "\treturn \"\"", - "}", - "", - }, "\n")) - - depPath := filepath.Join(root, "dep", "dep.go") - writeFile(t, depPath, strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"hello\"", - "}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - opts := &GenerateOptions{} - - first, errs := Generate(ctx, root, env, []string{"./app"}, opts) - if len(errs) > 0 { - t.Fatalf("first Generate errors: %v", errs) - } - if len(first) != 1 || len(first[0].Content) == 0 { - t.Fatalf("first Generate returned unexpected result: %+v", first) - } - - pkgs, _, errs := load(ctx, root, env, opts.Tags, []string{"./app"}) - if len(errs) > 0 || len(pkgs) != 1 { - t.Fatalf("load failed: %v", errs) - } - key, err := cacheKeyForPackage(pkgs[0], opts) - if err != nil { - t.Fatalf("cacheKeyForPackage failed: %v", err) - } - if cached, ok := readCache(key); !ok || len(cached) == 0 { - t.Fatal("expected cache entry after first Generate") - } - - writeFile(t, depPath, 
strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"goodbye\"", - "}", - "", - }, "\n")) - - second, errs := Generate(ctx, root, env, []string{"./app"}, opts) - if len(errs) > 0 { - t.Fatalf("second Generate errors: %v", errs) - } - if len(second) != 1 || len(second[0].Content) == 0 { - t.Fatalf("second Generate returned unexpected result: %+v", second) - } - pkgs, _, errs = load(ctx, root, env, opts.Tags, []string{"./app"}) - if len(errs) > 0 || len(pkgs) != 1 { - t.Fatalf("reload failed: %v", errs) - } - key2, err := cacheKeyForPackage(pkgs[0], opts) - if err != nil { - t.Fatalf("cacheKeyForPackage after update failed: %v", err) - } - if key2 == key { - t.Fatal("expected cache key to change after source update") - } - if !IncrementalEnabled(ctx, env) { - if cached, ok := readCache(key2); !ok || len(cached) == 0 { - t.Fatal("expected cache entry after second Generate") - } - } -} - -func TestManifestInvalidation(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - prevTmp := os.Getenv("TMPDIR") - if err := os.Setenv("TMPDIR", t.TempDir()); err != nil { - t.Fatalf("Setenv TMPDIR failed: %v", err) - } - t.Cleanup(func() { - os.Setenv("TMPDIR", prevTmp) - }) - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() string {", - "\twire.Build(dep.ProvideMessage)", - "\treturn \"\"", - "}", - "", - }, "\n")) - - depPath := filepath.Join(root, "dep", "dep.go") - writeFile(t, depPath, strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - 
"\treturn \"hello\"", - "}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - opts := &GenerateOptions{} - - if _, errs := Generate(ctx, root, env, []string{"./app"}, opts); len(errs) > 0 { - t.Fatalf("Generate errors: %v", errs) - } - - key := manifestKey(root, env, []string{"./app"}, opts) - manifest, ok := readManifest(key) - if !ok { - t.Fatal("expected manifest after Generate") - } - if !manifestValid(manifest) { - t.Fatal("expected manifest to be valid") - } - - writeFile(t, depPath, strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"goodbye\"", - "}", - "", - }, "\n")) - - if manifestValid(manifest) { - t.Fatal("expected manifest to be invalid after source update") - } -} - -func TestManifestInvalidationGoMod(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - prevTmp := os.Getenv("TMPDIR") - if err := os.Setenv("TMPDIR", t.TempDir()); err != nil { - t.Fatalf("Setenv TMPDIR failed: %v", err) - } - t.Cleanup(func() { - os.Setenv("TMPDIR", prevTmp) - }) - - goModPath := filepath.Join(root, "go.mod") - writeFile(t, goModPath, strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() string {", - "\twire.Build(dep.ProvideMessage)", - "\treturn \"\"", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"hello\"", - "}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - opts := &GenerateOptions{} - 
- if _, errs := Generate(ctx, root, env, []string{"./app"}, opts); len(errs) > 0 { - t.Fatalf("Generate errors: %v", errs) - } - - key := manifestKey(root, env, []string{"./app"}, opts) - manifest, ok := readManifest(key) - if !ok { - t.Fatal("expected manifest after Generate") - } - if !manifestValid(manifest) { - t.Fatal("expected manifest to be valid") - } - - writeFile(t, goModPath, strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0 // updated", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - if manifestValid(manifest) { - t.Fatal("expected manifest to be invalid after go.mod update") - } -} - -func TestManifestInvalidationSameTimestamp(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - prevTmp := os.Getenv("TMPDIR") - if err := os.Setenv("TMPDIR", t.TempDir()); err != nil { - t.Fatalf("Setenv TMPDIR failed: %v", err) - } - t.Cleanup(func() { - os.Setenv("TMPDIR", prevTmp) - }) - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - wirePath := filepath.Join(root, "app", "wire.go") - writeFile(t, wirePath, strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() string {", - "\twire.Build(dep.ProvideMessage)", - "\treturn \"\"", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "func ProvideMessage() string {", - "\treturn \"hello\"", - "}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - opts := &GenerateOptions{} - - if _, errs := Generate(ctx, root, env, []string{"./app"}, opts); len(errs) > 
0 { - t.Fatalf("Generate errors: %v", errs) - } - - key := manifestKey(root, env, []string{"./app"}, opts) - manifest, ok := readManifest(key) - if !ok { - t.Fatal("expected manifest after Generate") - } - if !manifestValid(manifest) { - t.Fatal("expected manifest to be valid") - } - - info, err := os.Stat(wirePath) - if err != nil { - t.Fatalf("Stat failed: %v", err) - } - originalMod := info.ModTime() - - original, err := os.ReadFile(wirePath) - if err != nil { - t.Fatalf("ReadFile failed: %v", err) - } - updated := strings.Replace(string(original), "ProvideMessage", "ProvideMassage", 1) - if len(updated) != len(original) { - t.Fatalf("expected updated content to keep length; got %d vs %d", len(updated), len(original)) - } - if err := os.WriteFile(wirePath, []byte(updated), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - if err := os.Chtimes(wirePath, originalMod, originalMod); err != nil { - t.Fatalf("Chtimes failed: %v", err) - } - - if manifestValid(manifest) { - t.Fatal("expected manifest to be invalid after same-timestamp content update") - } -} diff --git a/internal/wire/generate_package.go b/internal/wire/generate_package.go deleted file mode 100644 index 01d3d20..0000000 --- a/internal/wire/generate_package.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2018 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package wire - -import ( - "context" - "errors" - "fmt" - "go/format" - "path/filepath" - "time" - - "golang.org/x/tools/go/packages" -) - -// generateForPackage runs Wire code generation for a single package. -func generateForPackage(ctx context.Context, pkg *packages.Package, loader *lazyLoader, opts *GenerateOptions) GenerateResult { - if opts == nil { - opts = &GenerateOptions{} - } - pkgStart := time.Now() - res := GenerateResult{ - PkgPath: pkg.PkgPath, - } - dirStart := time.Now() - outDir, err := detectOutputDir(pkg.GoFiles) - logTiming(ctx, "generate.package."+pkg.PkgPath+".output_dir", dirStart) - if err != nil { - res.Errs = append(res.Errs, err) - return res - } - res.OutputPath = filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") - cacheKey, err := cacheKeyForPackage(pkg, opts) - if err != nil { - res.Errs = append(res.Errs, err) - return res - } - if cacheKey != "" && !bypassPackageCache(ctx) { - cacheHitStart := time.Now() - if cached, ok := readCache(cacheKey); ok { - res.Content = cached - logTiming(ctx, "generate.package."+pkg.PkgPath+".cache_hit", cacheHitStart) - logTiming(ctx, "generate.package."+pkg.PkgPath+".total", pkgStart) - return res - } - } - oc := newObjectCache([]*packages.Package{pkg}, loader) - if loaded, errs := oc.ensurePackage(pkg.PkgPath); len(errs) > 0 { - res.Errs = append(res.Errs, errs...) 
- return res - } else if loaded != nil { - pkg = loaded - } - g := newGen(pkg) - injectorStart := time.Now() - injectorFiles, errs := generateInjectors(oc, g, pkg) - logTiming(ctx, "generate.package."+pkg.PkgPath+".injectors", injectorStart) - if len(errs) > 0 { - res.Errs = errs - return res - } - copyStart := time.Now() - copyNonInjectorDecls(g, injectorFiles, pkg.TypesInfo) - logTiming(ctx, "generate.package."+pkg.PkgPath+".copy_non_injectors", copyStart) - frameStart := time.Now() - goSrc := g.frame(opts.Tags) - logTiming(ctx, "generate.package."+pkg.PkgPath+".frame", frameStart) - if len(opts.Header) > 0 { - goSrc = append(opts.Header, goSrc...) - } - formatStart := time.Now() - fmtSrc, err := format.Source(goSrc) - logTiming(ctx, "generate.package."+pkg.PkgPath+".format", formatStart) - if err != nil { - // This is likely a bug from a poorly generated source file. - // Add an error but also the unformatted source. - res.Errs = append(res.Errs, err) - } else { - goSrc = fmtSrc - } - res.Content = goSrc - if cacheKey != "" && len(res.Errs) == 0 { - writeCache(cacheKey, res.Content) - } - logTiming(ctx, "generate.package."+pkg.PkgPath+".total", pkgStart) - return res -} - -// allGeneratedOK reports whether every package result succeeded. -func allGeneratedOK(results []GenerateResult) bool { - if len(results) == 0 { - return false - } - for _, res := range results { - if len(res.Errs) > 0 { - return false - } - } - return true -} - -// detectOutputDir returns a shared directory for the provided file paths. 
-func detectOutputDir(paths []string) (string, error) { - if len(paths) == 0 { - return "", errors.New("no files to derive output directory from") - } - dir := filepath.Dir(paths[0]) - for _, p := range paths[1:] { - if dir2 := filepath.Dir(p); dir2 != dir { - return "", fmt.Errorf("found conflicting directories %q and %q", dir, dir2) - } - } - return dir, nil -} diff --git a/internal/wire/generate_package_test.go b/internal/wire/generate_package_test.go deleted file mode 100644 index 51b15b1..0000000 --- a/internal/wire/generate_package_test.go +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package wire - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" - - "golang.org/x/tools/go/packages" -) - -func TestGenerateForPackageOptionAndDetectErrors(t *testing.T) { - res := generateForPackage(context.Background(), &packages.Package{PkgPath: "example.com/empty"}, nil, nil) - if len(res.Errs) == 0 { - t.Fatal("expected error for empty package") - } - if _, err := detectOutputDir(nil); err == nil { - t.Fatal("expected detectOutputDir error") - } -} - -func TestGenerateForPackageCacheKeyError(t *testing.T) { - tempDir := t.TempDir() - missing := filepath.Join(tempDir, "missing.go") - pkg := &packages.Package{ - PkgPath: "example.com/missing", - GoFiles: []string{missing}, - } - res := generateForPackage(context.Background(), pkg, nil, &GenerateOptions{}) - if len(res.Errs) == 0 { - t.Fatal("expected cache key error") - } -} - -func TestGenerateForPackageCacheHit(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - file := writeTempFile(t, tempDir, "hit.go", "package hit\n") - pkg := &packages.Package{ - PkgPath: "example.com/hit", - GoFiles: []string{file}, - } - opts := &GenerateOptions{} - key, err := cacheKeyForPackage(pkg, opts) - if err != nil || key == "" { - t.Fatalf("cacheKeyForPackage failed: %v", err) - } - writeCache(key, []byte("cached")) - res := generateForPackage(context.Background(), pkg, nil, opts) - if string(res.Content) != "cached" { - t.Fatalf("expected cached content, got %q", res.Content) - } -} - -func TestGenerateForPackageFormatError(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - tempDir := t.TempDir() - osTempDir = func() string { return tempDir } - - repoRoot := mustRepoRoot(t) - writeTempFile(t, tempDir, "go.mod", strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require 
github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - appDir := filepath.Join(tempDir, "app") - if err := os.MkdirAll(appDir, 0755); err != nil { - t.Fatalf("MkdirAll failed: %v", err) - } - writeTempFile(t, appDir, "wire.go", strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import \"github.com/goforj/wire\"", - "", - "func Init() string {", - "\twire.Build(NewMessage)", - "\treturn \"\"", - "}", - "", - "func NewMessage() string { return \"ok\" }", - "", - }, "\n")) - - ctx := context.Background() - env := append(os.Environ(), "GOWORK=off") - pkgs, loader, errs := load(ctx, tempDir, env, "", []string{"./app"}) - if len(errs) > 0 || len(pkgs) != 1 { - t.Fatalf("load errors: %v", errs) - } - opts := &GenerateOptions{Header: []byte("invalid")} - res := generateForPackage(ctx, pkgs[0], loader, opts) - if len(res.Errs) == 0 { - t.Fatal("expected format.Source error") - } -} - -func TestAllGeneratedOK(t *testing.T) { - if allGeneratedOK(nil) { - t.Fatal("expected empty results to be false") - } - if allGeneratedOK([]GenerateResult{{Errs: []error{context.DeadlineExceeded}}}) { - t.Fatal("expected errors to be false") - } - if !allGeneratedOK([]GenerateResult{{}}) { - t.Fatal("expected success results to be true") - } -} diff --git a/internal/wire/incremental.go b/internal/wire/incremental.go deleted file mode 100644 index 007027b..0000000 --- a/internal/wire/incremental.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "strconv" - "strings" -) - -const IncrementalEnvVar = "WIRE_INCREMENTAL" - -type incrementalKey struct{} -type incrementalColdBootstrapKey struct{} - -// WithIncremental overrides incremental-mode resolution for the provided -// context. This takes precedence over the environment variable. -func WithIncremental(ctx context.Context, enabled bool) context.Context { - if ctx == nil { - ctx = context.Background() - } - return context.WithValue(ctx, incrementalKey{}, enabled) -} - -func withIncrementalColdBootstrap(ctx context.Context, enabled bool) context.Context { - if ctx == nil { - ctx = context.Background() - } - return context.WithValue(ctx, incrementalColdBootstrapKey{}, enabled) -} - -// IncrementalEnabled reports whether incremental mode is enabled for the -// current operation. A context override takes precedence over env. 
-func IncrementalEnabled(ctx context.Context, env []string) bool { - if ctx != nil { - if v := ctx.Value(incrementalKey{}); v != nil { - if enabled, ok := v.(bool); ok { - return enabled - } - } - } - raw, ok := lookupEnv(env, IncrementalEnvVar) - if !ok { - return false - } - enabled, err := strconv.ParseBool(strings.TrimSpace(raw)) - if err != nil { - return false - } - return enabled -} - -func incrementalColdBootstrapEnabled(ctx context.Context) bool { - if ctx == nil { - return false - } - if v := ctx.Value(incrementalColdBootstrapKey{}); v != nil { - if enabled, ok := v.(bool); ok { - return enabled - } - } - return false -} - -func lookupEnv(env []string, key string) (string, bool) { - prefix := key + "=" - for i := len(env) - 1; i >= 0; i-- { - if strings.HasPrefix(env[i], prefix) { - return strings.TrimPrefix(env[i], prefix), true - } - } - return "", false -} diff --git a/internal/wire/incremental_bench_test.go b/internal/wire/incremental_bench_test.go deleted file mode 100644 index b300c4f..0000000 --- a/internal/wire/incremental_bench_test.go +++ /dev/null @@ -1,1495 +0,0 @@ -package wire - -import ( - "context" - "fmt" - "os" - "path/filepath" - "sort" - "strconv" - "strings" - "testing" - "time" - "unicode/utf8" -) - -const ( - largeBenchmarkTestPackageCount = 24 - largeBenchmarkHelperCount = 12 -) - -var largeBenchmarkSizes = []int{10, 100, 1000} - -type incrementalScenarioBenchmarkCase struct { - name string - mutate func(tb testing.TB, root string) - measure func(tb testing.TB, root string, env []string, ctx context.Context) incrementalScenarioTrace - wantErr bool -} - -type incrementalScenarioTrace struct { - total time.Duration - labels map[string]time.Duration -} - -type incrementalScenarioBudget struct { - total time.Duration - validateLocal time.Duration - validateExt time.Duration - validateTouch time.Duration - validateTouchHit time.Duration - outputs time.Duration - generateLoad time.Duration - localFastpath time.Duration -} - -type 
largeRepoPerformanceBudget struct { - shapeTotal time.Duration - localLoad time.Duration - parse time.Duration - typecheck time.Duration - generate time.Duration - knownToggle time.Duration -} - -func BenchmarkGenerateIncrementalFirstSeenShapeChange(b *testing.B) { - cacheHooksMu.Lock() - state := saveCacheHooks() - b.Cleanup(func() { - restoreCacheHooks(state) - cacheHooksMu.Unlock() - }) - - repoRoot := benchmarkRepoRoot(b) - - for i := 0; i < b.N; i++ { - cacheRoot := b.TempDir() - osTempDir = func() string { return cacheRoot } - - root := b.TempDir() - writeIncrementalBenchmarkModule(b, repoRoot, root) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - b.Fatalf("baseline Generate returned errors: %v", errs) - } - - writeBenchmarkFile(b, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string; Count int }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func NewCount() int { return 7 }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: msg, Count: count}", - "}", - "", - }, "\n")) - writeBenchmarkFile(b, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - - b.StartTimer() - gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - b.StopTimer() - if len(errs) > 0 { - b.Fatalf("incremental shape-change Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - b.Fatalf("unexpected Generate results: %+v", gens) - } - } -} - -func BenchmarkGenerateIncrementalScenarioMatrix(b *testing.B) { - cacheHooksMu.Lock() - state := saveCacheHooks() - b.Cleanup(func() { - restoreCacheHooks(state) - cacheHooksMu.Unlock() - 
}) - - repoRoot := benchmarkRepoRoot(b) - for _, scenario := range incrementalScenarioBenchmarks() { - scenario := scenario - b.Run(scenario.name, func(b *testing.B) { - for i := 0; i < b.N; i++ { - b.StartTimer() - _ = measureIncrementalScenarioOnce(b, repoRoot, scenario) - b.StopTimer() - } - }) - } -} - -func TestPrintIncrementalScenarioBenchmarkTable(t *testing.T) { - if os.Getenv("WIRE_BENCH_SCENARIOS") == "" { - t.Skip("set WIRE_BENCH_SCENARIOS=1 to print the incremental scenario benchmark table") - } - - cacheHooksMu.Lock() - state := saveCacheHooks() - t.Cleanup(func() { - restoreCacheHooks(state) - cacheHooksMu.Unlock() - }) - - repoRoot := benchmarkRepoRoot(t) - rows := [][]string{{ - "scenario", - "total", - "local pkgs", - "external", - "touched", - "touch hit", - "outputs", - "gen load", - "local fastpath", - }} - for _, scenario := range incrementalScenarioBenchmarks() { - trace := measureIncrementalScenarioOnce(t, repoRoot, scenario) - rows = append(rows, []string{ - scenario.name, - formatBenchmarkDuration(trace.total), - formatBenchmarkDuration(trace.label("incremental.preload_manifest.validate_local_packages")), - formatBenchmarkDuration(trace.label("incremental.preload_manifest.validate_external_files")), - formatBenchmarkDuration(trace.label("incremental.preload_manifest.validate_touched")), - formatBenchmarkDuration(trace.label("incremental.preload_manifest.validate_touched_cache_hit")), - formatBenchmarkDuration(trace.label("incremental.preload_manifest.outputs")), - formatBenchmarkDuration(trace.label("generate.load")), - formatBenchmarkDuration(trace.label("incremental.local_fastpath.load")), - }) - } - fmt.Print(renderASCIITable(rows)) -} - -func TestIncrementalScenarioPerformanceBudgets(t *testing.T) { - if os.Getenv("WIRE_PERF_BUDGETS") == "" { - t.Skip("set WIRE_PERF_BUDGETS=1 to enforce incremental scenario performance budgets") - } - - cacheHooksMu.Lock() - state := saveCacheHooks() - t.Cleanup(func() { - restoreCacheHooks(state) - 
cacheHooksMu.Unlock() - }) - - repoRoot := benchmarkRepoRoot(t) - budgets := incrementalScenarioPerformanceBudgets() - for _, scenario := range incrementalScenarioBenchmarks() { - scenario := scenario - budget, ok := budgets[scenario.name] - if !ok { - t.Fatalf("missing performance budget for scenario %q", scenario.name) - } - t.Run(scenario.name, func(t *testing.T) { - trace := measureIncrementalScenarioMedian(t, repoRoot, scenario, 5) - assertScenarioBudget(t, trace, budget) - }) - } -} - -func TestLargeRepoPerformanceBudgets(t *testing.T) { - if os.Getenv("WIRE_PERF_BUDGETS") == "" { - t.Skip("set WIRE_PERF_BUDGETS=1 to enforce incremental scenario performance budgets") - } - - cacheHooksMu.Lock() - state := saveCacheHooks() - t.Cleanup(func() { - restoreCacheHooks(state) - cacheHooksMu.Unlock() - }) - - repoRoot := benchmarkRepoRoot(t) - budgets := largeRepoPerformanceBudgets() - for _, packageCount := range largeBenchmarkSizes { - packageCount := packageCount - budget, ok := budgets[packageCount] - if !ok { - t.Fatalf("missing large-repo performance budget for size %d", packageCount) - } - t.Run(strconv.Itoa(packageCount), func(t *testing.T) { - trace := measureLargeRepoShapeChangeTraceMedian(t, repoRoot, packageCount, true, 3) - checkBudgetDuration(t, "shape_total", trace.total, budget.shapeTotal) - checkBudgetDuration(t, "local_fastpath_load", trace.label("incremental.local_fastpath.load"), budget.localLoad) - checkBudgetDuration(t, "parse", trace.label("incremental.local_fastpath.parse"), budget.parse) - checkBudgetDuration(t, "typecheck", trace.label("incremental.local_fastpath.typecheck"), budget.typecheck) - checkBudgetDuration(t, "generate", trace.label("incremental.local_fastpath.generate"), budget.generate) - - knownToggle := measureLargeRepoKnownToggleMedian(t, repoRoot, packageCount, 3) - checkBudgetDuration(t, "known_toggle", knownToggle, budget.knownToggle) - }) - } -} - -func BenchmarkGenerateLargeRepoNormalShapeChange(b *testing.B) { - 
runLargeRepoShapeChangeBenchmarks(b, false) -} - -func BenchmarkGenerateLargeRepoIncrementalShapeChange(b *testing.B) { - runLargeRepoShapeChangeBenchmarks(b, true) -} - -func TestPrintLargeRepoBenchmarkComparisonTable(t *testing.T) { - if os.Getenv("WIRE_BENCH_TABLE") == "" { - t.Skip("set WIRE_BENCH_TABLE=1 to print the large-repo benchmark comparison table") - } - - cacheHooksMu.Lock() - state := saveCacheHooks() - t.Cleanup(func() { - restoreCacheHooks(state) - cacheHooksMu.Unlock() - }) - - repoRoot := benchmarkRepoRoot(t) - rows := make([]largeRepoBenchmarkRow, 0, len(largeBenchmarkSizes)) - for _, packageCount := range largeBenchmarkSizes { - coldNormal := measureLargeRepoColdOnce(t, repoRoot, packageCount, false) - coldIncremental := measureLargeRepoColdOnce(t, repoRoot, packageCount, true) - normal := measureLargeRepoShapeChangeOnce(t, repoRoot, packageCount, false) - incremental := measureLargeRepoShapeChangeOnce(t, repoRoot, packageCount, true) - knownToggle := measureLargeRepoKnownToggleOnce(t, repoRoot, packageCount) - rows = append(rows, largeRepoBenchmarkRow{ - packageCount: packageCount, - coldNormal: coldNormal, - coldIncremental: coldIncremental, - normal: normal, - incremental: incremental, - knownToggle: knownToggle, - }) - } - - table := [][]string{{ - "repo size", - "cold old", - "cold new", - "cold delta", - "shape old", - "shape new", - "shape delta", - "known toggle", - "cold speedup", - "shape speedup", - }} - for _, row := range rows { - table = append(table, []string{ - strconv.Itoa(row.packageCount), - formatBenchmarkDuration(row.coldNormal), - formatBenchmarkDuration(row.coldIncremental), - formatPercentImprovement(row.coldNormal, row.coldIncremental), - formatBenchmarkDuration(row.normal), - formatBenchmarkDuration(row.incremental), - formatPercentImprovement(row.normal, row.incremental), - formatBenchmarkDuration(row.knownToggle), - fmt.Sprintf("%.2fx", speedupRatio(row.coldNormal, row.coldIncremental)), - fmt.Sprintf("%.2fx", 
speedupRatio(row.normal, row.incremental)), - }) - } - fmt.Print(renderASCIITable(table)) -} - -func TestPrintLargeRepoShapeChangeBreakdownTable(t *testing.T) { - if os.Getenv("WIRE_BENCH_BREAKDOWN") == "" { - t.Skip("set WIRE_BENCH_BREAKDOWN=1 to print the large-repo shape-change breakdown table") - } - - cacheHooksMu.Lock() - state := saveCacheHooks() - t.Cleanup(func() { - restoreCacheHooks(state) - cacheHooksMu.Unlock() - }) - - repoRoot := benchmarkRepoRoot(t) - rows := [][]string{{ - "repo size", - "old total", - "old base load", - "old typed load", - "new total", - "new local load", - "new parse", - "new typecheck", - "new injector solve", - "new format", - "new generate", - "speedup", - }} - for _, packageCount := range largeBenchmarkSizes { - normal := measureLargeRepoShapeChangeTraceOnce(t, repoRoot, packageCount, false) - incremental := measureLargeRepoShapeChangeTraceOnce(t, repoRoot, packageCount, true) - rows = append(rows, []string{ - strconv.Itoa(packageCount), - formatBenchmarkDuration(normal.total), - formatBenchmarkDuration(normal.label("load.packages.base.load")), - formatBenchmarkDuration(normal.label("load.packages.lazy.load")), - formatBenchmarkDuration(incremental.total), - formatBenchmarkDuration(incremental.label("incremental.local_fastpath.load")), - formatBenchmarkDuration(incremental.label("incremental.local_fastpath.parse")), - formatBenchmarkDuration(incremental.label("incremental.local_fastpath.typecheck")), - formatBenchmarkDuration(incremental.label("generate.package.example.com/app/app.injectors")), - formatBenchmarkDuration(incremental.label("generate.package.example.com/app/app.format")), - formatBenchmarkDuration(incremental.label("incremental.local_fastpath.generate")), - fmt.Sprintf("%.2fx", speedupRatio(normal.total, incremental.total)), - }) - } - fmt.Print(renderASCIITable(rows)) -} - -func writeIncrementalBenchmarkModule(tb testing.TB, repoRoot string, root string) { - tb.Helper() - - writeBenchmarkFile(tb, 
filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeBenchmarkFile(tb, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - writeBenchmarkFile(tb, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) -} - -func TestGenerateIncrementalLargeRepoShapeChangeMatchesNormalGenerate(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := benchmarkRepoRoot(t) - root := t.TempDir() - writeLargeBenchmarkModule(t, repoRoot, root, largeBenchmarkTestPackageCount) - - env := append(os.Environ(), "GOWORK=off") - incrementalCtx := WithIncremental(context.Background(), true) - - if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - t.Fatalf("baseline incremental Generate returned errors: %v", errs) - } - - mutateLargeBenchmarkModule(t, root, largeBenchmarkTestPackageCount/2) - - var incrementalLabels []string - incrementalTimedCtx := WithTiming(incrementalCtx, func(label string, _ 
time.Duration) { - incrementalLabels = append(incrementalLabels, label) - }) - incrementalGens, errs := Generate(incrementalTimedCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("incremental large-repo Generate returned errors: %v", errs) - } - if len(incrementalGens) != 1 || len(incrementalGens[0].Errs) > 0 { - t.Fatalf("unexpected incremental results: %+v", incrementalGens) - } - if !containsLabel(incrementalLabels, "incremental.local_fastpath.load") { - t.Fatalf("expected large-repo shape change to use local fast path, labels=%v", incrementalLabels) - } - - normalGens, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("normal large-repo Generate returned errors: %v", errs) - } - if len(normalGens) != 1 || len(normalGens[0].Errs) > 0 { - t.Fatalf("unexpected normal results: %+v", normalGens) - } - if incrementalGens[0].OutputPath != normalGens[0].OutputPath { - t.Fatalf("output paths differ: incremental=%q normal=%q", incrementalGens[0].OutputPath, normalGens[0].OutputPath) - } - if string(incrementalGens[0].Content) != string(normalGens[0].Content) { - t.Fatal("large-repo shape-changing incremental output differs from normal Generate output") - } -} - -func runLargeRepoShapeChangeBenchmarks(b *testing.B, incremental bool) { - cacheHooksMu.Lock() - state := saveCacheHooks() - b.Cleanup(func() { - restoreCacheHooks(state) - cacheHooksMu.Unlock() - }) - - repoRoot := benchmarkRepoRoot(b) - for _, packageCount := range largeBenchmarkSizes { - packageCount := packageCount - b.Run(fmt.Sprintf("size=%d", packageCount), func(b *testing.B) { - for i := 0; i < b.N; i++ { - b.StartTimer() - _ = measureLargeRepoShapeChangeOnce(b, repoRoot, packageCount, incremental) - b.StopTimer() - } - }) - } -} - -func incrementalScenarioBenchmarks() []incrementalScenarioBenchmarkCase { - return []incrementalScenarioBenchmarkCase{ - { - name: "preload_unchanged", - mutate: 
func(testing.TB, string) {}, - }, - { - name: "preload_whitespace_only_change", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "", - "func New(msg string) *Foo {", - "", - "\treturn &Foo{Message: helper(msg)}", - "", - "}", - "", - }, "\n")) - }, - }, - { - name: "preload_body_only_change", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string {", - "\treturn helper(SQLText)", - "}", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - }, - { - name: "preload_body_only_repeat_change", - measure: func(tb testing.TB, root string, env []string, ctx context.Context) incrementalScenarioTrace { - writeBodyOnlyScenarioVariant(tb, root, "b") - if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - tb.Fatalf("warm changed variant Generate returned errors: %v", errs) - } - writeBodyOnlyScenarioVariant(tb, root, "a") - if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - tb.Fatalf("reset variant Generate returned errors: %v", errs) - } - writeBodyOnlyScenarioVariant(tb, root, "b") - trace := incrementalScenarioTrace{labels: make(map[string]time.Duration)} - timedCtx := WithTiming(ctx, func(label string, dur time.Duration) { - trace.labels[label] += dur - }) - start := time.Now() - gens, errs := 
Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) - trace.total = time.Since(start) - if len(errs) > 0 { - tb.Fatalf("%s: Generate returned errors: %v", "preload_body_only_repeat_change", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - tb.Fatalf("%s: unexpected Generate results: %+v", "preload_body_only_repeat_change", gens) - } - return trace - }, - }, - { - name: "local_fastpath_method_body_change", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func (f Foo) Summary() string {", - "\treturn helper(f.Message)", - "}", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - }, - }, - { - name: "preload_const_value_change", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"blue\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - }, - { - name: "preload_var_initializer_change", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 2", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - 
"\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - }, - { - name: "local_fastpath_add_top_level_helper", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func NewTag() string { return \"tag\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - }, - { - name: "preload_import_only_implementation_change", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "import \"fmt\"", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return fmt.Sprint(msg) }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - }, - { - name: "local_fastpath_signature_change", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 7", - "", - "type Foo struct {", - "\tMessage string", - "\tCount int", - "}", - "", - "func NewMessage() string { return SQLText }", - "", - "func NewCount() int { return defaultCount }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: helper(msg), Count: count}", - "}", - "", - }, "\n")) - writeBenchmarkFile(tb, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - 
"import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - }, - }, - { - name: "local_fastpath_struct_field_addition", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct {", - "\tMessage string", - "\tCount int", - "}", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg), Count: defaultCount}", - "}", - "", - }, "\n")) - }, - }, - { - name: "local_fastpath_interface_method_addition", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Fooer interface {", - "\tMessage() string", - "\tCount() int", - "}", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - }, - { - name: "fallback_invalid_body_change", - mutate: func(tb testing.TB, root string) { - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return missing }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - wantErr: true, - }, - } -} - -func incrementalScenarioPerformanceBudgets() map[string]incrementalScenarioBudget { - 
return map[string]incrementalScenarioBudget{ - "preload_unchanged": { - total: 300 * time.Millisecond, - validateLocal: 25 * time.Millisecond, - validateExt: 25 * time.Millisecond, - validateTouch: 5 * time.Millisecond, - outputs: 5 * time.Millisecond, - }, - "preload_whitespace_only_change": { - total: 300 * time.Millisecond, - validateLocal: 25 * time.Millisecond, - validateExt: 25 * time.Millisecond, - validateTouch: 250 * time.Millisecond, - outputs: 5 * time.Millisecond, - }, - "preload_body_only_change": { - total: 400 * time.Millisecond, - validateLocal: 40 * time.Millisecond, - validateExt: 40 * time.Millisecond, - validateTouch: 250 * time.Millisecond, - outputs: 5 * time.Millisecond, - }, - "preload_body_only_repeat_change": { - total: 150 * time.Millisecond, - validateLocal: 40 * time.Millisecond, - validateExt: 40 * time.Millisecond, - validateTouch: 5 * time.Millisecond, - validateTouchHit: 5 * time.Millisecond, - outputs: 5 * time.Millisecond, - }, - "local_fastpath_method_body_change": { - total: 500 * time.Millisecond, - validateLocal: 60 * time.Millisecond, - validateExt: 60 * time.Millisecond, - localFastpath: 300 * time.Millisecond, - }, - "preload_import_only_implementation_change": { - total: 150 * time.Millisecond, - validateLocal: 40 * time.Millisecond, - validateExt: 40 * time.Millisecond, - validateTouch: 50 * time.Millisecond, - outputs: 5 * time.Millisecond, - }, - "preload_const_value_change": { - total: 400 * time.Millisecond, - validateLocal: 40 * time.Millisecond, - validateExt: 40 * time.Millisecond, - validateTouch: 250 * time.Millisecond, - outputs: 5 * time.Millisecond, - }, - "preload_var_initializer_change": { - total: 400 * time.Millisecond, - validateLocal: 40 * time.Millisecond, - validateExt: 40 * time.Millisecond, - validateTouch: 250 * time.Millisecond, - outputs: 5 * time.Millisecond, - }, - "local_fastpath_add_top_level_helper": { - total: 500 * time.Millisecond, - validateLocal: 60 * time.Millisecond, - validateExt: 60 
* time.Millisecond, - localFastpath: 300 * time.Millisecond, - }, - "local_fastpath_signature_change": { - total: 500 * time.Millisecond, - validateLocal: 60 * time.Millisecond, - validateExt: 60 * time.Millisecond, - localFastpath: 300 * time.Millisecond, - }, - "local_fastpath_struct_field_addition": { - total: 500 * time.Millisecond, - validateLocal: 60 * time.Millisecond, - validateExt: 60 * time.Millisecond, - localFastpath: 300 * time.Millisecond, - }, - "local_fastpath_interface_method_addition": { - total: 500 * time.Millisecond, - validateLocal: 60 * time.Millisecond, - validateExt: 60 * time.Millisecond, - localFastpath: 300 * time.Millisecond, - }, - "fallback_invalid_body_change": { - total: 800 * time.Millisecond, - generateLoad: 500 * time.Millisecond, - }, - } -} - -func measureIncrementalScenarioOnce(tb testing.TB, repoRoot string, scenario incrementalScenarioBenchmarkCase) incrementalScenarioTrace { - tb.Helper() - - cacheRoot := tb.TempDir() - osTempDir = func() string { return cacheRoot } - - root := tb.TempDir() - writeIncrementalScenarioBenchmarkModule(tb, repoRoot, root) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - tb.Fatalf("baseline Generate returned errors: %v", errs) - } - - if scenario.measure != nil { - return scenario.measure(tb, root, env, ctx) - } - - scenario.mutate(tb, root) - - trace := incrementalScenarioTrace{labels: make(map[string]time.Duration)} - timedCtx := WithTiming(ctx, func(label string, dur time.Duration) { - trace.labels[label] += dur - }) - start := time.Now() - gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) - trace.total = time.Since(start) - - if scenario.wantErr { - if len(errs) == 0 { - tb.Fatalf("%s: expected Generate errors", scenario.name) - } - if len(gens) != 0 { - tb.Fatalf("%s: expected no generated results on error, 
got %+v", scenario.name, gens) - } - return trace - } - - if len(errs) > 0 { - tb.Fatalf("%s: Generate returned errors: %v", scenario.name, errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - tb.Fatalf("%s: unexpected Generate results: %+v", scenario.name, gens) - } - return trace -} - -func writeIncrementalScenarioBenchmarkModule(tb testing.TB, repoRoot string, root string) { - tb.Helper() - - writeBenchmarkFile(tb, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeBenchmarkFile(tb, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - writeBodyOnlyScenarioVariant(tb, root, "green") -} - -func writeBodyOnlyScenarioVariant(tb testing.TB, root string, value string) { - tb.Helper() - writeBenchmarkFile(tb, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "const SQLText = \"" + value + "\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - - writeBenchmarkFile(tb, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) -} - -func measureIncrementalScenarioMedian(tb testing.TB, repoRoot string, scenario incrementalScenarioBenchmarkCase, samples int) incrementalScenarioTrace { - tb.Helper() 
- if samples <= 0 { - samples = 1 - } - traces := make([]incrementalScenarioTrace, 0, samples) - for i := 0; i < samples; i++ { - traces = append(traces, measureIncrementalScenarioOnce(tb, repoRoot, scenario)) - } - sort.Slice(traces, func(i, j int) bool { return traces[i].total < traces[j].total }) - return traces[len(traces)/2] -} - -func assertScenarioBudget(t *testing.T, trace incrementalScenarioTrace, budget incrementalScenarioBudget) { - t.Helper() - checkBudgetDuration(t, "total", trace.total, budget.total) - checkBudgetDuration(t, "validate_local_packages", trace.label("incremental.preload_manifest.validate_local_packages"), budget.validateLocal) - checkBudgetDuration(t, "validate_external_files", trace.label("incremental.preload_manifest.validate_external_files"), budget.validateExt) - checkBudgetDuration(t, "validate_touched", trace.label("incremental.preload_manifest.validate_touched"), budget.validateTouch) - checkBudgetDuration(t, "validate_touched_cache_hit", trace.label("incremental.preload_manifest.validate_touched_cache_hit"), budget.validateTouchHit) - checkBudgetDuration(t, "outputs", trace.label("incremental.preload_manifest.outputs"), budget.outputs) - checkBudgetDuration(t, "generate_load", trace.label("generate.load"), budget.generateLoad) - checkBudgetDuration(t, "local_fastpath_load", trace.label("incremental.local_fastpath.load"), budget.localFastpath) -} - -func checkBudgetDuration(t *testing.T, name string, got time.Duration, max time.Duration) { - t.Helper() - if max <= 0 { - return - } - if got > max { - t.Fatalf("%s exceeded budget: got=%s max=%s", name, got, max) - } -} - -func (s incrementalScenarioTrace) label(name string) time.Duration { - if s.labels == nil { - return 0 - } - return s.labels[name] -} - -type largeRepoBenchmarkRow struct { - packageCount int - coldNormal time.Duration - coldIncremental time.Duration - normal time.Duration - incremental time.Duration - knownToggle time.Duration -} - -type shapeChangeTrace struct { 
- total time.Duration - labels map[string]time.Duration -} - -func largeRepoPerformanceBudgets() map[int]largeRepoPerformanceBudget { - return map[int]largeRepoPerformanceBudget{ - 10: { - shapeTotal: 45 * time.Millisecond, - localLoad: 3 * time.Millisecond, - parse: 500 * time.Microsecond, - typecheck: 4 * time.Millisecond, - generate: 3 * time.Millisecond, - knownToggle: 3 * time.Millisecond, - }, - 100: { - shapeTotal: 35 * time.Millisecond, - localLoad: 20 * time.Millisecond, - parse: 1500 * time.Microsecond, - typecheck: 12 * time.Millisecond, - generate: 20 * time.Millisecond, - knownToggle: 15 * time.Millisecond, - }, - 1000: { - shapeTotal: 260 * time.Millisecond, - localLoad: 110 * time.Millisecond, - parse: 4 * time.Millisecond, - typecheck: 70 * time.Millisecond, - generate: 180 * time.Millisecond, - knownToggle: 90 * time.Millisecond, - }, - } -} - -func measureLargeRepoShapeChangeOnce(tb testing.TB, repoRoot string, packageCount int, incremental bool) time.Duration { - tb.Helper() - - cacheRoot := tb.TempDir() - osTempDir = func() string { return cacheRoot } - - root := tb.TempDir() - writeLargeBenchmarkModule(tb, repoRoot, root, packageCount) - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - if incremental { - ctx = WithIncremental(ctx, true) - } - if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - tb.Fatalf("baseline Generate returned errors: %v", errs) - } - - mutateLargeBenchmarkModule(tb, root, packageCount/2) - - start := time.Now() - gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - dur := time.Since(start) - if len(errs) > 0 { - tb.Fatalf("shape-change Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - tb.Fatalf("unexpected Generate results: %+v", gens) - } - return dur -} - -func measureLargeRepoShapeChangeTraceOnce(tb testing.TB, repoRoot string, packageCount int, incremental bool) shapeChangeTrace { - 
tb.Helper() - - cacheRoot := tb.TempDir() - osTempDir = func() string { return cacheRoot } - - root := tb.TempDir() - writeLargeBenchmarkModule(tb, repoRoot, root, packageCount) - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - if incremental { - ctx = WithIncremental(ctx, true) - } - if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - tb.Fatalf("baseline Generate returned errors: %v", errs) - } - - mutateLargeBenchmarkModule(tb, root, packageCount/2) - - trace := shapeChangeTrace{labels: make(map[string]time.Duration)} - ctx = WithTiming(ctx, func(label string, dur time.Duration) { - trace.labels[label] += dur - }) - start := time.Now() - gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - trace.total = time.Since(start) - if len(errs) > 0 { - tb.Fatalf("shape-change Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - tb.Fatalf("unexpected Generate results: %+v", gens) - } - return trace -} - -func measureLargeRepoShapeChangeTraceMedian(tb testing.TB, repoRoot string, packageCount int, incremental bool, samples int) shapeChangeTrace { - tb.Helper() - if samples <= 0 { - samples = 1 - } - traces := make([]shapeChangeTrace, 0, samples) - for i := 0; i < samples; i++ { - traces = append(traces, measureLargeRepoShapeChangeTraceOnce(tb, repoRoot, packageCount, incremental)) - } - sort.Slice(traces, func(i, j int) bool { return traces[i].total < traces[j].total }) - return traces[len(traces)/2] -} - -func (s shapeChangeTrace) label(name string) time.Duration { - if s.labels == nil { - return 0 - } - return s.labels[name] -} - -func measureLargeRepoColdOnce(tb testing.TB, repoRoot string, packageCount int, incremental bool) time.Duration { - tb.Helper() - - cacheRoot := tb.TempDir() - osTempDir = func() string { return cacheRoot } - - root := tb.TempDir() - writeLargeBenchmarkModule(tb, repoRoot, root, packageCount) - env := 
append(os.Environ(), "GOWORK=off") - ctx := context.Background() - if incremental { - ctx = WithIncremental(ctx, true) - } - - start := time.Now() - gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - dur := time.Since(start) - if len(errs) > 0 { - tb.Fatalf("cold Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - tb.Fatalf("unexpected cold Generate results: %+v", gens) - } - return dur -} - -func measureLargeRepoKnownToggleOnce(tb testing.TB, repoRoot string, packageCount int) time.Duration { - tb.Helper() - - cacheRoot := tb.TempDir() - osTempDir = func() string { return cacheRoot } - - root := tb.TempDir() - writeLargeBenchmarkModule(tb, repoRoot, root, packageCount) - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - mutatedIndex := packageCount / 2 - - if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - tb.Fatalf("baseline Generate returned errors: %v", errs) - } - - mutateLargeBenchmarkModule(tb, root, mutatedIndex) - gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - tb.Fatalf("mutated Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - tb.Fatalf("unexpected mutated Generate results: %+v", gens) - } - - writeLargeBenchmarkPackage(tb, root, mutatedIndex, false) - - start := time.Now() - gens, errs = Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - dur := time.Since(start) - if len(errs) > 0 { - tb.Fatalf("toggle Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - tb.Fatalf("unexpected toggle Generate results: %+v", gens) - } - return dur -} - -func measureLargeRepoKnownToggleMedian(tb testing.TB, repoRoot string, packageCount int, samples int) time.Duration { - tb.Helper() - if samples <= 0 { - samples = 1 - } - values := make([]time.Duration, 0, samples) - for i 
:= 0; i < samples; i++ { - values = append(values, measureLargeRepoKnownToggleOnce(tb, repoRoot, packageCount)) - } - sort.Slice(values, func(i, j int) bool { return values[i] < values[j] }) - return values[len(values)/2] -} - -func formatPercentImprovement(normal time.Duration, incremental time.Duration) string { - if normal <= 0 { - return "0.0%" - } - improvement := 100 * (float64(normal-incremental) / float64(normal)) - return fmt.Sprintf("%.1f%%", improvement) -} - -func speedupRatio(normal time.Duration, incremental time.Duration) float64 { - if incremental <= 0 { - return 0 - } - return float64(normal) / float64(incremental) -} - -func formatBenchmarkDuration(d time.Duration) string { - switch { - case d >= time.Second: - return fmt.Sprintf("%.2fs", d.Seconds()) - case d >= time.Millisecond: - return fmt.Sprintf("%.2fms", float64(d)/float64(time.Millisecond)) - case d >= time.Microsecond: - return fmt.Sprintf("%.2fus", float64(d)/float64(time.Microsecond)) - default: - return d.String() - } -} - -func writeLargeBenchmarkModule(tb testing.TB, repoRoot string, root string, packageCount int) { - tb.Helper() - - writeBenchmarkFile(tb, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - wireImports := []string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"github.com/goforj/wire\"", - } - appImports := []string{ - "package app", - "", - "import (", - } - buildArgs := []string{"\twire.Build("} - argNames := make([]string, 0, packageCount) - for i := 0; i < packageCount; i++ { - pkgName := fmt.Sprintf("layer%02d", i) - wireImports = append(wireImports, fmt.Sprintf("\t%s %q", pkgName, "example.com/app/"+pkgName)) - appImports = append(appImports, fmt.Sprintf("\t%s %q", pkgName, "example.com/app/"+pkgName)) - buildArgs = append(buildArgs, 
fmt.Sprintf("\t\t%s.NewSet,", pkgName)) - argNames = append(argNames, fmt.Sprintf("dep%02d *%s.Token", i, pkgName)) - } - wireImports = append(wireImports, ")", "") - appImports = append(appImports, ")", "") - wireFile := append([]string{}, wireImports...) - wireFile = append(wireFile, "func Init() *App {") - wireFile = append(wireFile, buildArgs...) - wireFile = append(wireFile, "\t\tNewApp,", "\t)", "\treturn nil", "}", "") - writeBenchmarkFile(tb, filepath.Join(root, "app", "wire.go"), strings.Join(wireFile, "\n")) - - appGo := append(appImports[:len(appImports)-2], // reuse imports without trailing blank line - ")", - "", - "type App struct {", - "\tCount int", - "}", - "", - fmt.Sprintf("func NewApp(%s) *App {", strings.Join(argNames, ", ")), - fmt.Sprintf("\treturn &App{Count: %d}", packageCount), - "}", - "", - ) - writeBenchmarkFile(tb, filepath.Join(root, "app", "app.go"), strings.Join(appGo, "\n")) - - for i := 0; i < packageCount; i++ { - writeLargeBenchmarkPackage(tb, root, i, false) - } -} - -func mutateLargeBenchmarkModule(tb testing.TB, root string, mutatedIndex int) { - tb.Helper() - writeLargeBenchmarkPackage(tb, root, mutatedIndex, true) -} - -func writeLargeBenchmarkPackage(tb testing.TB, root string, index int, mutated bool) { - tb.Helper() - - pkgName := fmt.Sprintf("layer%02d", index) - pkgDir := filepath.Join(root, pkgName) - - writeBenchmarkFile(tb, filepath.Join(pkgDir, "helpers.go"), renderLargeBenchmarkHelpers(pkgName, index, mutated)) - writeBenchmarkFile(tb, filepath.Join(pkgDir, "wire.go"), renderLargeBenchmarkWire(pkgName, mutated)) -} - -func renderLargeBenchmarkHelpers(pkgName string, index int, mutated bool) string { - lines := []string{ - "package " + pkgName, - "", - "import (", - "\t\"fmt\"", - "\t\"strconv\"", - "\t\"strings\"", - ")", - "", - "type Config struct {", - "\tLabel string", - "}", - "", - "type Weight int", - "", - "type Token struct {", - "\tConfig Config", - "\tWeight Weight", - "}", - "", - fmt.Sprintf("func 
NewConfig() Config { return Config{Label: %q} }", pkgName), - "", - } - if mutated { - lines = append(lines, - fmt.Sprintf("func NewWeight() Weight { return Weight(%d) }", index+100), - "", - "func New(cfg Config, weight Weight) *Token {", - fmt.Sprintf("\t_ = helper%02d()", largeBenchmarkHelperCount-1), - "\treturn &Token{Config: cfg, Weight: weight}", - "}", - "", - ) - } else { - lines = append(lines, - "func New(cfg Config) *Token {", - fmt.Sprintf("\t_ = helper%02d()", largeBenchmarkHelperCount-1), - "\treturn &Token{Config: cfg}", - "}", - "", - ) - } - for i := 0; i < largeBenchmarkHelperCount; i++ { - lines = append(lines, fmt.Sprintf("func helper%02d() string {", i)) - lines = append(lines, fmt.Sprintf("\treturn strings.ToUpper(fmt.Sprintf(\"%%s-%%d\", %q, %d)) + strconv.Itoa(%d)", pkgName, i, index+i)) - lines = append(lines, "}", "") - } - return strings.Join(lines, "\n") -} - -func renderLargeBenchmarkWire(pkgName string, mutated bool) string { - lines := []string{ - "package " + pkgName, - "", - "import (", - "\t\"github.com/goforj/wire\"", - ")", - "", - } - if mutated { - lines = append(lines, "var NewSet = wire.NewSet(NewConfig, NewWeight, New)", "") - } else { - lines = append(lines, "var NewSet = wire.NewSet(NewConfig, New)", "") - } - return strings.Join(lines, "\n") -} - -func strconvQuote(s string) string { - return fmt.Sprintf("%q", s) -} - -func benchmarkRepoRoot(tb testing.TB) string { - tb.Helper() - wd, err := os.Getwd() - if err != nil { - tb.Fatalf("Getwd failed: %v", err) - } - repoRoot := filepath.Clean(filepath.Join(wd, "..", "..")) - if _, err := os.Stat(filepath.Join(repoRoot, "go.mod")); err != nil { - tb.Fatalf("repo root not found at %s: %v", repoRoot, err) - } - return repoRoot -} - -func writeBenchmarkFile(tb testing.TB, path string, content string) { - tb.Helper() - if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { - tb.Fatalf("MkdirAll failed: %v", err) - } - if err := os.WriteFile(path, []byte(content), 0644); 
err != nil { - tb.Fatalf("WriteFile failed: %v", err) - } -} - -func renderASCIITable(rows [][]string) string { - if len(rows) == 0 { - return "" - } - widths := make([]int, len(rows[0])) - for _, row := range rows { - for i, cell := range row { - if width := utf8.RuneCountInString(cell); width > widths[i] { - widths[i] = width - } - } - } - var b strings.Builder - border := func() { - b.WriteByte('+') - for _, width := range widths { - b.WriteString(strings.Repeat("-", width+2)) - b.WriteByte('+') - } - b.WriteByte('\n') - } - writeRow := func(row []string) { - b.WriteByte('|') - for i, cell := range row { - b.WriteByte(' ') - b.WriteString(cell) - b.WriteString(strings.Repeat(" ", widths[i]-utf8.RuneCountInString(cell)+1)) - b.WriteByte('|') - } - b.WriteByte('\n') - } - border() - writeRow(rows[0]) - border() - for _, row := range rows[1:] { - writeRow(row) - } - border() - return b.String() -} diff --git a/internal/wire/incremental_fingerprint.go b/internal/wire/incremental_fingerprint.go deleted file mode 100644 index be39982..0000000 --- a/internal/wire/incremental_fingerprint.go +++ /dev/null @@ -1,674 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package wire - -import ( - "bytes" - "context" - "crypto/sha256" - "encoding/binary" - "fmt" - "go/ast" - "go/parser" - "go/printer" - "go/token" - "path/filepath" - "sort" - "strings" - - "golang.org/x/tools/go/packages" -) - -const incrementalFingerprintVersion = "wire-incremental-v3" - -type packageFingerprint struct { - Version string - WD string - Tags string - PkgPath string - Files []cacheFile - Dirs []cacheFile - ContentHash string - ShapeHash string - LocalImports []string -} - -type fingerprintStats struct { - localPackages int - metaHits int - metaMisses int - unchanged int - changed int -} - -type incrementalFingerprintSnapshot struct { - stats fingerprintStats - changed []string - touched []string - fingerprints map[string]*packageFingerprint -} - -func analyzeIncrementalFingerprints(ctx context.Context, wd string, env []string, tags string, pkgs []*packages.Package) *incrementalFingerprintSnapshot { - if !IncrementalEnabled(ctx, env) { - return nil - } - start := timeNow() - snapshot := collectIncrementalFingerprints(wd, tags, pkgs) - debugf(ctx, "incremental.fingerprint local_pkgs=%d meta_hits=%d meta_misses=%d unchanged=%d changed=%d total=%s", - snapshot.stats.localPackages, - snapshot.stats.metaHits, - snapshot.stats.metaMisses, - snapshot.stats.unchanged, - snapshot.stats.changed, - timeSince(start), - ) - if len(snapshot.changed) > 0 { - debugf(ctx, "incremental.fingerprint changed_pkgs=%s", strings.Join(snapshot.changed, ", ")) - } - return snapshot -} - -func collectIncrementalFingerprints(wd string, tags string, pkgs []*packages.Package) *incrementalFingerprintSnapshot { - all := collectAllPackages(pkgs) - moduleRoot := findModuleRoot(wd) - snapshot := &incrementalFingerprintSnapshot{ - fingerprints: make(map[string]*packageFingerprint), - } - for _, pkg := range all { - if classifyPackageLocation(moduleRoot, pkg) != "local" { - continue - } - snapshot.stats.localPackages++ - files := packageFingerprintFiles(pkg) - if len(files) == 0 { - 
continue - } - sort.Strings(files) - metaFiles, err := buildCacheFiles(files) - if err != nil { - snapshot.stats.metaMisses++ - continue - } - key := incrementalFingerprintKey(wd, tags, pkg.PkgPath) - if prev, ok := readIncrementalFingerprint(key); ok && incrementalFingerprintMetaMatches(prev, wd, tags, pkg.PkgPath, metaFiles) { - snapshot.stats.metaHits++ - snapshot.stats.unchanged++ - snapshot.fingerprints[pkg.PkgPath] = prev - continue - } - snapshot.stats.metaMisses++ - snapshot.touched = append(snapshot.touched, pkg.PkgPath) - fp, err := buildPackageFingerprint(wd, tags, pkg, metaFiles) - if err != nil { - continue - } - prev, hadPrev := readIncrementalFingerprint(key) - writeIncrementalFingerprint(key, fp) - snapshot.fingerprints[pkg.PkgPath] = fp - if hadPrev && incrementalFingerprintEquivalent(prev, fp) { - snapshot.stats.unchanged++ - continue - } - snapshot.stats.changed++ - snapshot.changed = append(snapshot.changed, pkg.PkgPath) - } - sort.Strings(snapshot.changed) - sort.Strings(snapshot.touched) - return snapshot -} - -func buildIncrementalManifestSnapshotFromPackages(wd string, tags string, pkgs []*packages.Package) *incrementalFingerprintSnapshot { - all := collectAllPackages(pkgs) - moduleRoot := findModuleRoot(wd) - snapshot := &incrementalFingerprintSnapshot{ - fingerprints: make(map[string]*packageFingerprint), - } - for _, pkg := range all { - if classifyPackageLocation(moduleRoot, pkg) != "local" { - continue - } - files := packageFingerprintFiles(pkg) - if len(files) == 0 { - continue - } - sort.Strings(files) - metaFiles, err := buildCacheFiles(files) - if err != nil { - continue - } - shapeHash, err := packageShapeHashFromSyntax(pkg, files) - if err != nil { - continue - } - localImports := make([]string, 0, len(pkg.Imports)) - for _, imp := range pkg.Imports { - if classifyPackageLocation(moduleRoot, imp) == "local" { - localImports = append(localImports, imp.PkgPath) - } - } - sort.Strings(localImports) - 
snapshot.fingerprints[pkg.PkgPath] = &packageFingerprint{ - Version: incrementalFingerprintVersion, - WD: packageCacheScope(wd), - Tags: tags, - PkgPath: pkg.PkgPath, - Files: metaFiles, - Dirs: mustBuildPackageDirCacheFiles(files), - ContentHash: mustHashPackageFiles(files), - ShapeHash: shapeHash, - LocalImports: localImports, - } - } - if len(snapshot.fingerprints) == 0 { - return nil - } - return snapshot -} - -func packageFingerprintFiles(pkg *packages.Package) []string { - if pkg == nil { - return nil - } - if len(pkg.CompiledGoFiles) > 0 { - return append([]string(nil), pkg.CompiledGoFiles...) - } - return append([]string(nil), pkg.GoFiles...) -} - -func packageFingerprintDirs(files []string) []string { - if len(files) == 0 { - return nil - } - dirs := make([]string, 0, len(files)) - seen := make(map[string]struct{}, len(files)) - for _, name := range files { - dir := filepath.Clean(filepath.Dir(name)) - if _, ok := seen[dir]; ok { - continue - } - seen[dir] = struct{}{} - dirs = append(dirs, dir) - } - sort.Strings(dirs) - return dirs -} - -func mustBuildPackageDirCacheFiles(files []string) []cacheFile { - dirs := packageFingerprintDirs(files) - if len(dirs) == 0 { - return nil - } - meta, err := buildCacheFiles(dirs) - if err != nil { - return nil - } - return meta -} - -func mustHashPackageFiles(files []string) string { - if len(files) == 0 { - return "" - } - hash, err := hashFiles(files) - if err != nil { - return "" - } - return hash -} - -func incrementalFingerprintEquivalent(a, b *packageFingerprint) bool { - if a == nil || b == nil { - return false - } - if a.ShapeHash != b.ShapeHash || a.PkgPath != b.PkgPath || a.Tags != b.Tags || a.WD != b.WD { - return false - } - if len(a.LocalImports) != len(b.LocalImports) { - return false - } - for i := range a.LocalImports { - if a.LocalImports[i] != b.LocalImports[i] { - return false - } - } - return true -} - -func incrementalFingerprintMetaMatches(prev *packageFingerprint, wd string, tags string, pkgPath 
string, files []cacheFile) bool { - if prev == nil || prev.Version != incrementalFingerprintVersion { - return false - } - if prev.WD != packageCacheScope(wd) || prev.Tags != tags || prev.PkgPath != pkgPath { - return false - } - if len(prev.Files) != len(files) { - return false - } - for i := range prev.Files { - if prev.Files[i] != files[i] { - return false - } - } - return true -} - -func buildPackageFingerprint(wd string, tags string, pkg *packages.Package, files []cacheFile) (*packageFingerprint, error) { - shapeHash, err := packageShapeHash(packageFingerprintFiles(pkg)) - if err != nil { - return nil, err - } - localImports := make([]string, 0, len(pkg.Imports)) - moduleRoot := findModuleRoot(wd) - for _, imp := range pkg.Imports { - if classifyPackageLocation(moduleRoot, imp) == "local" { - localImports = append(localImports, imp.PkgPath) - } - } - sort.Strings(localImports) - return &packageFingerprint{ - Version: incrementalFingerprintVersion, - WD: packageCacheScope(wd), - Tags: tags, - PkgPath: pkg.PkgPath, - Files: append([]cacheFile(nil), files...), - Dirs: mustBuildPackageDirCacheFiles(packageFingerprintFiles(pkg)), - ContentHash: mustHashPackageFiles(packageFingerprintFiles(pkg)), - ShapeHash: shapeHash, - LocalImports: localImports, - }, nil -} - -func packageShapeHash(files []string) (string, error) { - fset := token.NewFileSet() - var buf bytes.Buffer - for _, name := range files { - file, err := parser.ParseFile(fset, name, nil, parser.SkipObjectResolution) - if err != nil { - return "", err - } - writeSyntaxShapeHash(&buf, fset, file) - buf.WriteByte(0) - } - sum := sha256.Sum256(buf.Bytes()) - return fmt.Sprintf("%x", sum[:]), nil -} - -func packageShapeHashFromSyntax(pkg *packages.Package, files []string) (string, error) { - if pkg == nil || len(pkg.Syntax) == 0 || pkg.Fset == nil { - return packageShapeHash(files) - } - var buf bytes.Buffer - for _, file := range pkg.Syntax { - if file == nil { - continue - } - writeSyntaxShapeHash(&buf, 
pkg.Fset, file) - buf.WriteByte(0) - } - sum := sha256.Sum256(buf.Bytes()) - return fmt.Sprintf("%x", sum[:]), nil -} - -func writeSyntaxShapeHash(buf *bytes.Buffer, fset *token.FileSet, file *ast.File) { - if file == nil || buf == nil || fset == nil { - return - } - usedImports := usedImportNamesInShape(file) - if file.Name != nil { - buf.WriteString("package ") - buf.WriteString(file.Name.Name) - buf.WriteByte('\n') - } - for _, decl := range file.Decls { - switch decl := decl.(type) { - case *ast.FuncDecl: - writeNodeHash(buf, fset, decl.Recv) - buf.WriteByte(' ') - if decl.Name != nil { - buf.WriteString(decl.Name.Name) - } - buf.WriteByte(' ') - writeNodeHash(buf, fset, decl.Type) - buf.WriteByte('\n') - case *ast.GenDecl: - if writeGenDeclShapeHash(buf, fset, decl, usedImports) { - buf.WriteByte('\n') - } - default: - writeNodeHash(buf, fset, decl) - buf.WriteByte('\n') - } - } -} - -func writeGenDeclShapeHash(buf *bytes.Buffer, fset *token.FileSet, decl *ast.GenDecl, usedImports map[string]struct{}) bool { - if buf == nil || fset == nil || decl == nil { - return false - } - var specBuf bytes.Buffer - wrote := false - for _, spec := range decl.Specs { - switch spec := spec.(type) { - case *ast.ImportSpec: - name := importName(spec) - if name == "_" || name == "." 
{ - if spec.Name != nil { - specBuf.WriteString(spec.Name.Name) - } - specBuf.WriteByte(' ') - writeNodeHash(&specBuf, fset, spec.Path) - specBuf.WriteByte('\n') - wrote = true - break - } - if _, ok := usedImports[name]; !ok { - continue - } - if spec.Name != nil { - specBuf.WriteString(spec.Name.Name) - } - specBuf.WriteByte(' ') - writeNodeHash(&specBuf, fset, spec.Path) - case *ast.TypeSpec: - if spec.Name != nil { - specBuf.WriteString(spec.Name.Name) - } - specBuf.WriteByte(' ') - writeNodeHash(&specBuf, fset, spec.Type) - case *ast.ValueSpec: - for _, name := range spec.Names { - if name != nil { - specBuf.WriteString(name.Name) - } - specBuf.WriteByte(' ') - } - if spec.Type != nil { - writeNodeHash(&specBuf, fset, spec.Type) - } - default: - writeNodeHash(&specBuf, fset, spec) - } - specBuf.WriteByte('\n') - wrote = true - } - if !wrote { - return false - } - buf.WriteString(decl.Tok.String()) - buf.WriteByte(' ') - buf.Write(specBuf.Bytes()) - return true -} - -func usedImportNamesInShape(file *ast.File) map[string]struct{} { - used := make(map[string]struct{}) - if file == nil { - return used - } - record := func(node ast.Node) { - ast.Inspect(node, func(n ast.Node) bool { - sel, ok := n.(*ast.SelectorExpr) - if !ok { - return true - } - ident, ok := sel.X.(*ast.Ident) - if !ok || ident.Name == "" { - return true - } - used[ident.Name] = struct{}{} - return true - }) - } - for _, decl := range file.Decls { - switch decl := decl.(type) { - case *ast.FuncDecl: - if decl.Recv != nil { - record(decl.Recv) - } - if decl.Type != nil { - record(decl.Type) - } - case *ast.GenDecl: - for _, spec := range decl.Specs { - switch spec := spec.(type) { - case *ast.TypeSpec: - if spec.Type != nil { - record(spec.Type) - } - case *ast.ValueSpec: - if spec.Type != nil { - record(spec.Type) - } - } - } - } - } - return used -} - -func writeNodeHash(buf *bytes.Buffer, fset *token.FileSet, node interface{}) { - if buf == nil || fset == nil || node == nil { - return - } - _ 
= printer.Fprint(buf, fset, node) -} - -func stripFunctionBodies(file *ast.File) { - if file == nil { - return - } - for _, decl := range file.Decls { - if fn, ok := decl.(*ast.FuncDecl); ok { - fn.Body = nil - fn.Doc = nil - } - } -} - -func incrementalFingerprintKey(wd string, tags string, pkgPath string) string { - h := sha256.New() - h.Write([]byte(incrementalFingerprintVersion)) - h.Write([]byte{0}) - h.Write([]byte(packageCacheScope(wd))) - h.Write([]byte{0}) - h.Write([]byte(tags)) - h.Write([]byte{0}) - h.Write([]byte(pkgPath)) - return fmt.Sprintf("%x", h.Sum(nil)) -} - -func incrementalFingerprintPath(key string) string { - return filepath.Join(cacheDir(), key+".ifp") -} - -func readIncrementalFingerprint(key string) (*packageFingerprint, bool) { - data, err := osReadFile(incrementalFingerprintPath(key)) - if err != nil { - return nil, false - } - fp, err := decodeIncrementalFingerprint(data) - if err != nil { - return nil, false - } - return fp, true -} - -func writeIncrementalFingerprint(key string, fp *packageFingerprint) { - data, err := encodeIncrementalFingerprint(fp) - if err != nil { - return - } - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - tmp, err := osCreateTemp(dir, key+".ifp-") - if err != nil { - return - } - _, writeErr := tmp.Write(data) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - if err := osRename(tmp.Name(), incrementalFingerprintPath(key)); err != nil { - osRemove(tmp.Name()) - } -} - -func encodeIncrementalFingerprint(fp *packageFingerprint) ([]byte, error) { - var buf bytes.Buffer - writeString := func(s string) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(s))); err != nil { - return err - } - _, err := buf.WriteString(s) - return err - } - writeCacheFiles := func(files []cacheFile) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(files))); err != nil { - return err - } - for _, f := 
range files { - if err := writeString(f.Path); err != nil { - return err - } - if err := binary.Write(&buf, binary.LittleEndian, f.Size); err != nil { - return err - } - if err := binary.Write(&buf, binary.LittleEndian, f.ModTime); err != nil { - return err - } - } - return nil - } - writeStrings := func(items []string) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(items))); err != nil { - return err - } - for _, item := range items { - if err := writeString(item); err != nil { - return err - } - } - return nil - } - if fp == nil { - return nil, fmt.Errorf("nil fingerprint") - } - for _, s := range []string{fp.Version, fp.WD, fp.Tags, fp.PkgPath, fp.ShapeHash} { - if err := writeString(s); err != nil { - return nil, err - } - } - if err := writeCacheFiles(fp.Files); err != nil { - return nil, err - } - if err := writeStrings(fp.LocalImports); err != nil { - return nil, err - } - return buf.Bytes(), nil -} - -func decodeIncrementalFingerprint(data []byte) (*packageFingerprint, error) { - r := bytes.NewReader(data) - readString := func() (string, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return "", err - } - buf := make([]byte, n) - if _, err := r.Read(buf); err != nil { - return "", err - } - return string(buf), nil - } - readCacheFiles := func() ([]cacheFile, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return nil, err - } - out := make([]cacheFile, 0, n) - for i := uint32(0); i < n; i++ { - path, err := readString() - if err != nil { - return nil, err - } - var size int64 - if err := binary.Read(r, binary.LittleEndian, &size); err != nil { - return nil, err - } - var modTime int64 - if err := binary.Read(r, binary.LittleEndian, &modTime); err != nil { - return nil, err - } - out = append(out, cacheFile{Path: path, Size: size, ModTime: modTime}) - } - return out, nil - } - readStrings := func() ([]string, error) { - var n uint32 - if err := 
binary.Read(r, binary.LittleEndian, &n); err != nil { - return nil, err - } - out := make([]string, 0, n) - for i := uint32(0); i < n; i++ { - item, err := readString() - if err != nil { - return nil, err - } - out = append(out, item) - } - return out, nil - } - version, err := readString() - if err != nil { - return nil, err - } - wd, err := readString() - if err != nil { - return nil, err - } - tags, err := readString() - if err != nil { - return nil, err - } - pkgPath, err := readString() - if err != nil { - return nil, err - } - shapeHash, err := readString() - if err != nil { - return nil, err - } - files, err := readCacheFiles() - if err != nil { - return nil, err - } - localImports, err := readStrings() - if err != nil { - return nil, err - } - return &packageFingerprint{ - Version: version, - WD: wd, - Tags: tags, - PkgPath: pkgPath, - ShapeHash: shapeHash, - Files: files, - LocalImports: localImports, - }, nil -} diff --git a/internal/wire/incremental_fingerprint_test.go b/internal/wire/incremental_fingerprint_test.go deleted file mode 100644 index 920d08e..0000000 --- a/internal/wire/incremental_fingerprint_test.go +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package wire - -import ( - "os" - "path/filepath" - "testing" - - "golang.org/x/tools/go/packages" -) - -func TestPackageShapeHashIgnoresFunctionBodies(t *testing.T) { - dir := t.TempDir() - file := writeTempFile(t, dir, "pkg.go", "package p\n\nfunc Hello() string { return \"a\" }\n") - hash1, err := packageShapeHash([]string{file}) - if err != nil { - t.Fatalf("packageShapeHash first failed: %v", err) - } - if err := os.WriteFile(file, []byte("package p\n\nfunc Hello() string { return \"b\" }\n"), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - hash2, err := packageShapeHash([]string{file}) - if err != nil { - t.Fatalf("packageShapeHash second failed: %v", err) - } - if hash1 != hash2 { - t.Fatalf("body-only change should not affect shape hash: %q vs %q", hash1, hash2) - } -} - -func TestPackageShapeHashIgnoresConstValueChanges(t *testing.T) { - dir := t.TempDir() - file := writeTempFile(t, dir, "pkg.go", "package p\n\nconst SQLText = \"a\"\n") - hash1, err := packageShapeHash([]string{file}) - if err != nil { - t.Fatalf("packageShapeHash first failed: %v", err) - } - if err := os.WriteFile(file, []byte("package p\n\nconst SQLText = \"b\"\n"), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - hash2, err := packageShapeHash([]string{file}) - if err != nil { - t.Fatalf("packageShapeHash second failed: %v", err) - } - if hash1 != hash2 { - t.Fatalf("const-value change should not affect shape hash: %q vs %q", hash1, hash2) - } -} - -func TestPackageShapeHashIgnoresImplementationOnlyImportChanges(t *testing.T) { - dir := t.TempDir() - file := writeTempFile(t, dir, "pkg.go", "package p\n\nfunc Hello() string { return \"a\" }\n") - hash1, err := packageShapeHash([]string{file}) - if err != nil { - t.Fatalf("packageShapeHash first failed: %v", err) - } - if err := os.WriteFile(file, []byte("package p\n\nimport \"fmt\"\n\nfunc Hello() string { return fmt.Sprint(\"a\") }\n"), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - 
} - hash2, err := packageShapeHash([]string{file}) - if err != nil { - t.Fatalf("packageShapeHash second failed: %v", err) - } - if hash1 != hash2 { - t.Fatalf("implementation-only import change should not affect shape hash: %q vs %q", hash1, hash2) - } -} - -func TestIncrementalFingerprintRoundTrip(t *testing.T) { - fp := &packageFingerprint{ - Version: incrementalFingerprintVersion, - WD: "/tmp/app", - Tags: "dev", - PkgPath: "example.com/app", - ShapeHash: "shape", - Files: []cacheFile{{Path: "/tmp/app/pkg.go", Size: 12, ModTime: 34}}, - LocalImports: []string{"example.com/dep"}, - } - data, err := encodeIncrementalFingerprint(fp) - if err != nil { - t.Fatalf("encodeIncrementalFingerprint failed: %v", err) - } - got, err := decodeIncrementalFingerprint(data) - if err != nil { - t.Fatalf("decodeIncrementalFingerprint failed: %v", err) - } - if !incrementalFingerprintEquivalent(fp, got) { - t.Fatalf("fingerprint mismatch after round-trip: got %+v want %+v", got, fp) - } - if len(got.Files) != 1 || got.Files[0] != fp.Files[0] { - t.Fatalf("file metadata mismatch after round-trip: got %+v want %+v", got.Files, fp.Files) - } -} - -func TestCollectIncrementalFingerprintsTreatsBodyOnlyChangeAsUnchanged(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - root := t.TempDir() - writeFile(t, filepath.Join(root, "go.mod"), "module example.com/test\n\ngo 1.21\n") - file := filepath.Join(root, "app", "app.go") - writeFile(t, file, "package app\n\nfunc Hello() string { return \"a\" }\n") - pkg := &packages.Package{ - PkgPath: "example.com/app", - CompiledGoFiles: []string{file}, - GoFiles: []string{file}, - Imports: map[string]*packages.Package{}, - } - - snapshot := collectIncrementalFingerprints(root, "", []*packages.Package{pkg}) - if snapshot.stats.changed != 1 || len(snapshot.changed) != 1 || snapshot.changed[0] != pkg.PkgPath { - 
t.Fatalf("first run stats=%+v changed=%v", snapshot.stats, snapshot.changed) - } - - if err := os.WriteFile(file, []byte("package app\n\nfunc Hello() string { return \"b\" }\n"), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - snapshot = collectIncrementalFingerprints(root, "", []*packages.Package{pkg}) - if snapshot.stats.unchanged != 1 { - t.Fatalf("body-only change should be unchanged by shape, stats=%+v changed=%v", snapshot.stats, snapshot.changed) - } - if len(snapshot.changed) != 0 { - t.Fatalf("body-only change should not report changed packages, got %v", snapshot.changed) - } -} diff --git a/internal/wire/incremental_graph.go b/internal/wire/incremental_graph.go deleted file mode 100644 index 37b3d0f..0000000 --- a/internal/wire/incremental_graph.go +++ /dev/null @@ -1,306 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package wire - -import ( - "bytes" - "context" - "crypto/sha256" - "encoding/binary" - "fmt" - "path/filepath" - "sort" - "strings" - - "golang.org/x/tools/go/packages" -) - -const incrementalGraphVersion = "wire-incremental-graph-v1" - -type incrementalGraph struct { - Version string - WD string - Tags string - Roots []string - LocalReverse map[string][]string -} - -func analyzeIncrementalGraph(ctx context.Context, wd string, env []string, tags string, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot) { - if !IncrementalEnabled(ctx, env) || snapshot == nil { - return - } - graph := buildIncrementalGraph(wd, tags, pkgs) - writeIncrementalGraph(incrementalGraphKey(wd, tags, graph.Roots), graph) - if len(snapshot.changed) == 0 { - return - } - affected := affectedRoots(graph, snapshot.changed) - if len(affected) > 0 { - debugf(ctx, "incremental.graph changed=%s affected_roots=%s", stringsJoin(snapshot.changed), stringsJoin(affected)) - } else { - debugf(ctx, "incremental.graph changed=%s affected_roots=", stringsJoin(snapshot.changed)) - } -} - -func buildIncrementalGraph(wd string, tags string, pkgs []*packages.Package) *incrementalGraph { - moduleRoot := findModuleRoot(wd) - graph := &incrementalGraph{ - Version: incrementalGraphVersion, - WD: packageCacheScope(wd), - Tags: tags, - Roots: make([]string, 0, len(pkgs)), - LocalReverse: make(map[string][]string), - } - for _, pkg := range pkgs { - if pkg == nil { - continue - } - graph.Roots = append(graph.Roots, pkg.PkgPath) - } - sort.Strings(graph.Roots) - for _, pkg := range collectAllPackages(pkgs) { - if classifyPackageLocation(moduleRoot, pkg) != "local" { - continue - } - for _, imp := range pkg.Imports { - if classifyPackageLocation(moduleRoot, imp) != "local" { - continue - } - graph.LocalReverse[imp.PkgPath] = append(graph.LocalReverse[imp.PkgPath], pkg.PkgPath) - } - } - for path := range graph.LocalReverse { - sort.Strings(graph.LocalReverse[path]) - } - return graph -} - -func 
affectedRoots(graph *incrementalGraph, changed []string) []string { - if graph == nil || len(changed) == 0 { - return nil - } - rootSet := make(map[string]struct{}, len(graph.Roots)) - for _, root := range graph.Roots { - rootSet[root] = struct{}{} - } - seen := make(map[string]struct{}) - queue := append([]string(nil), changed...) - affected := make(map[string]struct{}) - for len(queue) > 0 { - cur := queue[0] - queue = queue[1:] - if _, ok := seen[cur]; ok { - continue - } - seen[cur] = struct{}{} - if _, ok := rootSet[cur]; ok { - affected[cur] = struct{}{} - } - for _, next := range graph.LocalReverse[cur] { - if _, ok := seen[next]; !ok { - queue = append(queue, next) - } - } - } - out := make([]string, 0, len(affected)) - for root := range affected { - out = append(out, root) - } - sort.Strings(out) - return out -} - -func incrementalGraphKey(wd string, tags string, roots []string) string { - h := sha256.New() - h.Write([]byte(incrementalGraphVersion)) - h.Write([]byte{0}) - h.Write([]byte(packageCacheScope(wd))) - h.Write([]byte{0}) - h.Write([]byte(tags)) - h.Write([]byte{0}) - for _, root := range roots { - h.Write([]byte(root)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} - -func incrementalGraphPath(key string) string { - return filepath.Join(cacheDir(), key+".igr") -} - -func writeIncrementalGraph(key string, graph *incrementalGraph) { - data, err := encodeIncrementalGraph(graph) - if err != nil { - return - } - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - tmp, err := osCreateTemp(dir, key+".igr-") - if err != nil { - return - } - _, writeErr := tmp.Write(data) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - if err := osRename(tmp.Name(), incrementalGraphPath(key)); err != nil { - osRemove(tmp.Name()) - } -} - -func readIncrementalGraph(key string) (*incrementalGraph, bool) { - data, err := osReadFile(incrementalGraphPath(key)) - if err != 
nil { - return nil, false - } - graph, err := decodeIncrementalGraph(data) - if err != nil { - return nil, false - } - return graph, true -} - -func encodeIncrementalGraph(graph *incrementalGraph) ([]byte, error) { - if graph == nil { - return nil, fmt.Errorf("nil incremental graph") - } - var buf bytes.Buffer - writeString := func(s string) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(s))); err != nil { - return err - } - _, err := buf.WriteString(s) - return err - } - for _, s := range []string{graph.Version, graph.WD, graph.Tags} { - if err := writeString(s); err != nil { - return nil, err - } - } - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(graph.Roots))); err != nil { - return nil, err - } - for _, root := range graph.Roots { - if err := writeString(root); err != nil { - return nil, err - } - } - keys := make([]string, 0, len(graph.LocalReverse)) - for k := range graph.LocalReverse { - keys = append(keys, k) - } - sort.Strings(keys) - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(keys))); err != nil { - return nil, err - } - for _, k := range keys { - if err := writeString(k); err != nil { - return nil, err - } - children := append([]string(nil), graph.LocalReverse[k]...) 
- sort.Strings(children) - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(children))); err != nil { - return nil, err - } - for _, child := range children { - if err := writeString(child); err != nil { - return nil, err - } - } - } - return buf.Bytes(), nil -} - -func decodeIncrementalGraph(data []byte) (*incrementalGraph, error) { - r := bytes.NewReader(data) - readString := func() (string, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return "", err - } - buf := make([]byte, n) - if _, err := r.Read(buf); err != nil { - return "", err - } - return string(buf), nil - } - version, err := readString() - if err != nil { - return nil, err - } - wd, err := readString() - if err != nil { - return nil, err - } - tags, err := readString() - if err != nil { - return nil, err - } - var rootCount uint32 - if err := binary.Read(r, binary.LittleEndian, &rootCount); err != nil { - return nil, err - } - roots := make([]string, 0, rootCount) - for i := uint32(0); i < rootCount; i++ { - root, err := readString() - if err != nil { - return nil, err - } - roots = append(roots, root) - } - var edgeCount uint32 - if err := binary.Read(r, binary.LittleEndian, &edgeCount); err != nil { - return nil, err - } - reverse := make(map[string][]string, edgeCount) - for i := uint32(0); i < edgeCount; i++ { - k, err := readString() - if err != nil { - return nil, err - } - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return nil, err - } - children := make([]string, 0, n) - for j := uint32(0); j < n; j++ { - child, err := readString() - if err != nil { - return nil, err - } - children = append(children, child) - } - reverse[k] = children - } - return &incrementalGraph{ - Version: version, - WD: wd, - Tags: tags, - Roots: roots, - LocalReverse: reverse, - }, nil -} - -func stringsJoin(items []string) string { - if len(items) == 0 { - return "" - } - return strings.Join(items, ",") -} diff --git 
a/internal/wire/incremental_graph_test.go b/internal/wire/incremental_graph_test.go deleted file mode 100644 index 8a91b54..0000000 --- a/internal/wire/incremental_graph_test.go +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "path/filepath" - "reflect" - "testing" - - "golang.org/x/tools/go/packages" -) - -func TestIncrementalGraphRoundTrip(t *testing.T) { - graph := &incrementalGraph{ - Version: incrementalGraphVersion, - WD: "/tmp/app", - Tags: "dev", - Roots: []string{"example.com/app", "example.com/other"}, - LocalReverse: map[string][]string{ - "example.com/dep": {"example.com/app"}, - "example.com/sub": {"example.com/dep", "example.com/other"}, - }, - } - data, err := encodeIncrementalGraph(graph) - if err != nil { - t.Fatalf("encodeIncrementalGraph failed: %v", err) - } - got, err := decodeIncrementalGraph(data) - if err != nil { - t.Fatalf("decodeIncrementalGraph failed: %v", err) - } - if !reflect.DeepEqual(got, graph) { - t.Fatalf("graph round-trip mismatch:\n got=%+v\nwant=%+v", got, graph) - } -} - -func TestAffectedRoots(t *testing.T) { - graph := &incrementalGraph{ - Roots: []string{"example.com/app", "example.com/other"}, - LocalReverse: map[string][]string{ - "example.com/dep": {"example.com/app"}, - "example.com/sub": {"example.com/dep", "example.com/other"}, - }, - } - got := affectedRoots(graph, []string{"example.com/sub"}) - want := 
[]string{"example.com/app", "example.com/other"} - if !reflect.DeepEqual(got, want) { - t.Fatalf("affectedRoots=%v want %v", got, want) - } -} - -func TestBuildIncrementalGraph(t *testing.T) { - root := t.TempDir() - writeFile(t, filepath.Join(root, "go.mod"), "module example.com/test\n\ngo 1.21\n") - - appFile := filepath.Join(root, "app", "app.go") - depFile := filepath.Join(root, "dep", "dep.go") - writeFile(t, appFile, "package app\n") - writeFile(t, depFile, "package dep\n") - - dep := &packages.Package{ - PkgPath: "example.com/test/dep", - CompiledGoFiles: []string{depFile}, - GoFiles: []string{depFile}, - Imports: map[string]*packages.Package{}, - } - app := &packages.Package{ - PkgPath: "example.com/test/app", - CompiledGoFiles: []string{appFile}, - GoFiles: []string{appFile}, - Imports: map[string]*packages.Package{ - "example.com/test/dep": dep, - }, - } - - graph := buildIncrementalGraph(root, "", []*packages.Package{app}) - if len(graph.Roots) != 1 || graph.Roots[0] != app.PkgPath { - t.Fatalf("unexpected roots: %v", graph.Roots) - } - got := graph.LocalReverse[dep.PkgPath] - want := []string{app.PkgPath} - if !reflect.DeepEqual(got, want) { - t.Fatalf("unexpected reverse edges: got=%v want=%v", got, want) - } -} diff --git a/internal/wire/incremental_manifest.go b/internal/wire/incremental_manifest.go deleted file mode 100644 index 8fab10e..0000000 --- a/internal/wire/incremental_manifest.go +++ /dev/null @@ -1,1158 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "bytes" - "context" - "crypto/sha256" - "encoding/binary" - "fmt" - "go/token" - "os" - "path/filepath" - "sort" - "strings" - - "golang.org/x/tools/go/packages" -) - -const incrementalManifestVersion = "wire-incremental-manifest-v3" - -type incrementalManifest struct { - Version string - WD string - Tags string - Prefix string - HeaderHash string - EnvHash string - Patterns []string - LocalPackages []packageFingerprint - ExternalPkgs []externalPackageExport - ExternalFiles []cacheFile - ExtraFiles []cacheFile - Outputs []incrementalOutput -} - -type externalPackageExport struct { - PkgPath string - ExportFile string -} - -type incrementalOutput struct { - PkgPath string - OutputPath string - ContentKey string -} - -type incrementalPreloadState struct { - selectorKey string - manifest *incrementalManifest - valid bool - currentLocal []packageFingerprint - touched []string - reason string -} - -type incrementalPreloadValidation struct { - valid bool - currentLocal []packageFingerprint - touched []string - reason string -} - -const touchedValidationVersion = "wire-touched-validation-v1" - -func readPreloadIncrementalManifestResults(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) ([]GenerateResult, bool) { - state, ok := prepareIncrementalPreloadState(ctx, wd, env, patterns, opts) - return readPreloadIncrementalManifestResultsFromState(ctx, wd, env, patterns, opts, state, ok) -} - -func readPreloadIncrementalManifestResultsFromState(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions, state *incrementalPreloadState, ok bool) ([]GenerateResult, bool) { - if !ok { - debugf(ctx, "incremental.preload_manifest miss reason=no_manifest") - return nil, false - } - if state.valid { - validateStart := timeNow() - if len(state.touched) > 0 { - debugf(ctx, 
"incremental.preload_manifest touched=%s", strings.Join(state.touched, ",")) - } - if err := validateIncrementalPreloadTouchedPackages(ctx, wd, env, opts, state.currentLocal, state.touched); err != nil { - logTiming(ctx, "incremental.preload_manifest.validate_touched", validateStart) - if shouldBypassIncrementalManifestAfterFastPathError(err) { - invalidateIncrementalPreloadState(state) - } - debugf(ctx, "incremental.preload_manifest miss reason=touched_validation") - return nil, false - } - logTiming(ctx, "incremental.preload_manifest.validate_touched", validateStart) - outputsStart := timeNow() - results, ok := incrementalManifestOutputs(state.manifest) - logTiming(ctx, "incremental.preload_manifest.outputs", outputsStart) - if !ok { - debugf(ctx, "incremental.preload_manifest miss reason=outputs") - return nil, false - } - if manifestNeedsLocalRefresh(state.manifest.LocalPackages, state.currentLocal) { - refreshed := *state.manifest - refreshed.LocalPackages = append([]packageFingerprint(nil), state.currentLocal...) 
- writeIncrementalManifestFile(state.selectorKey, &refreshed) - writeIncrementalManifestFile(incrementalManifestStateKey(state.selectorKey, refreshed.LocalPackages), &refreshed) - } - debugf(ctx, "incremental.preload_manifest hit outputs=%d", len(results)) - return results, true - } else if archived := readStateIncrementalManifest(state.selectorKey, state.currentLocal); archived != nil { - if validation := incrementalManifestPreloadValid(ctx, archived, wd, env, patterns, opts); validation.valid { - results, ok := incrementalManifestOutputs(archived) - if !ok { - debugf(ctx, "incremental.preload_manifest miss reason=state_outputs") - return nil, false - } - writeIncrementalManifestFile(state.selectorKey, archived) - debugf(ctx, "incremental.preload_manifest state_hit outputs=%d", len(results)) - return results, true - } - debugf(ctx, "incremental.preload_manifest miss reason=%s", state.reason) - return nil, false - } else { - debugf(ctx, "incremental.preload_manifest miss reason=%s", state.reason) - return nil, false - } -} - -func prepareIncrementalPreloadState(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) (*incrementalPreloadState, bool) { - selectorKey := incrementalManifestSelectorKey(wd, env, patterns, opts) - manifest, ok := readIncrementalManifest(selectorKey) - if !ok { - return nil, false - } - validation := incrementalManifestPreloadValid(ctx, manifest, wd, env, patterns, opts) - return &incrementalPreloadState{ - selectorKey: selectorKey, - manifest: manifest, - valid: validation.valid, - currentLocal: validation.currentLocal, - touched: validation.touched, - reason: validation.reason, - }, true -} - -func readIncrementalManifestResults(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot) ([]GenerateResult, bool) { - if snapshot == nil || snapshot.stats.changed != 0 { - return nil, false - } - key := 
incrementalManifestSelectorKey(wd, env, patterns, opts) - manifest, ok := readIncrementalManifest(key) - if !ok || !incrementalManifestValid(manifest, wd, env, patterns, opts, pkgs) { - return nil, false - } - results := make([]GenerateResult, 0, len(manifest.Outputs)) - for _, out := range manifest.Outputs { - content, ok := readCache(out.ContentKey) - if !ok { - return nil, false - } - results = append(results, GenerateResult{ - PkgPath: out.PkgPath, - OutputPath: out.OutputPath, - Content: content, - }) - } - debugf(ctx, "incremental.manifest hit outputs=%d", len(results)) - return results, true -} - -func writeIncrementalManifest(wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot, generated []GenerateResult) { - writeIncrementalManifestWithOptions(wd, env, patterns, opts, pkgs, snapshot, generated, true) -} - -func writeIncrementalManifestWithOptions(wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package, snapshot *incrementalFingerprintSnapshot, generated []GenerateResult, includeExternalFiles bool) { - if snapshot == nil || len(generated) == 0 { - return - } - scope := runCacheScope(wd, patterns) - externalPkgs := buildExternalPackageExports(wd, pkgs) - var externalFiles []cacheFile - if includeExternalFiles { - var err error - externalFiles, err = buildExternalPackageFiles(wd, pkgs) - if err != nil { - return - } - } - manifest := &incrementalManifest{ - Version: incrementalManifestVersion, - WD: scope, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - EnvHash: envHash(env), - Patterns: normalizePatternsForScope(wd, packageCacheScope(wd), patterns), - LocalPackages: snapshotPackageFingerprints(snapshot), - ExternalPkgs: externalPkgs, - ExternalFiles: externalFiles, - ExtraFiles: extraCacheFiles(wd), - } - for _, out := range generated { - if len(out.Content) == 0 || out.OutputPath == "" { - 
continue - } - contentKey := incrementalContentKey(out.Content) - writeCache(contentKey, out.Content) - manifest.Outputs = append(manifest.Outputs, incrementalOutput{ - PkgPath: out.PkgPath, - OutputPath: out.OutputPath, - ContentKey: contentKey, - }) - } - if len(manifest.Outputs) == 0 { - return - } - selectorKey := incrementalManifestSelectorKey(wd, env, patterns, opts) - stateKey := incrementalManifestStateKey(selectorKey, manifest.LocalPackages) - writeIncrementalManifestFile(selectorKey, manifest) - writeIncrementalManifestFile(stateKey, manifest) -} - -func incrementalManifestSelectorKey(wd string, env []string, patterns []string, opts *GenerateOptions) string { - h := sha256.New() - h.Write([]byte(incrementalManifestVersion)) - h.Write([]byte{0}) - h.Write([]byte(runCacheScope(wd, patterns))) - h.Write([]byte{0}) - h.Write([]byte(envHash(env))) - h.Write([]byte{0}) - h.Write([]byte(opts.Tags)) - h.Write([]byte{0}) - h.Write([]byte(opts.PrefixOutputFile)) - h.Write([]byte{0}) - h.Write([]byte(headerHash(opts.Header))) - h.Write([]byte{0}) - for _, p := range normalizePatternsForScope(wd, packageCacheScope(wd), patterns) { - h.Write([]byte(p)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} - -func snapshotPackageFingerprints(snapshot *incrementalFingerprintSnapshot) []packageFingerprint { - if snapshot == nil || len(snapshot.fingerprints) == 0 { - return nil - } - paths := make([]string, 0, len(snapshot.fingerprints)) - for path := range snapshot.fingerprints { - paths = append(paths, path) - } - sort.Strings(paths) - out := make([]packageFingerprint, 0, len(paths)) - for _, path := range paths { - if fp := snapshot.fingerprints[path]; fp != nil { - out = append(out, *fp) - } - } - return out -} - -func buildExternalPackageFiles(wd string, pkgs []*packages.Package) ([]cacheFile, error) { - moduleRoot := findModuleRoot(wd) - seen := make(map[string]struct{}) - var files []string - for _, pkg := range collectAllPackages(pkgs) { - if 
classifyPackageLocation(moduleRoot, pkg) == "local" { - continue - } - names := pkg.CompiledGoFiles - if len(names) == 0 { - names = pkg.GoFiles - } - for _, name := range names { - clean := filepath.Clean(name) - if _, ok := seen[clean]; ok { - continue - } - seen[clean] = struct{}{} - files = append(files, clean) - } - } - sort.Strings(files) - return buildCacheFiles(files) -} - -func buildExternalPackageExports(wd string, pkgs []*packages.Package) []externalPackageExport { - out := make([]externalPackageExport, 0) - for _, pkg := range collectAllPackages(pkgs) { - if pkg == nil || pkg.PkgPath == "" || pkg.ExportFile == "" { - continue - } - out = append(out, externalPackageExport{ - PkgPath: pkg.PkgPath, - ExportFile: pkg.ExportFile, - }) - } - sort.Slice(out, func(i, j int) bool { return out[i].PkgPath < out[j].PkgPath }) - return out -} - -func incrementalManifestValid(manifest *incrementalManifest, wd string, env []string, patterns []string, opts *GenerateOptions, pkgs []*packages.Package) bool { - if manifest == nil || manifest.Version != incrementalManifestVersion { - return false - } - if manifest.WD != runCacheScope(wd, patterns) || manifest.Tags != opts.Tags || manifest.Prefix != opts.PrefixOutputFile { - return false - } - if manifest.HeaderHash != headerHash(opts.Header) || manifest.EnvHash != envHash(env) { - return false - } - normalizedPatterns := normalizePatternsForScope(wd, packageCacheScope(wd), patterns) - if len(manifest.Patterns) != len(normalizedPatterns) { - return false - } - for i, p := range normalizedPatterns { - if manifest.Patterns[i] != p { - return false - } - } - currentExternal, err := buildExternalPackageFiles(wd, pkgs) - if err != nil || len(currentExternal) != len(manifest.ExternalFiles) { - return false - } - for i := range currentExternal { - if currentExternal[i] != manifest.ExternalFiles[i] { - return false - } - } - if len(manifest.ExtraFiles) > 0 { - current, err := buildCacheFilesFromMeta(manifest.ExtraFiles) - if err != 
nil || len(current) != len(manifest.ExtraFiles) { - return false - } - for i := range current { - if current[i] != manifest.ExtraFiles[i] { - return false - } - } - } - return len(manifest.Outputs) > 0 -} - -func incrementalManifestPreloadValid(ctx context.Context, manifest *incrementalManifest, wd string, env []string, patterns []string, opts *GenerateOptions) incrementalPreloadValidation { - if manifest == nil || manifest.Version != incrementalManifestVersion { - return incrementalPreloadValidation{reason: "version"} - } - if manifest.WD != runCacheScope(wd, patterns) || manifest.Tags != opts.Tags || manifest.Prefix != opts.PrefixOutputFile { - return incrementalPreloadValidation{reason: "config"} - } - if manifest.HeaderHash != headerHash(opts.Header) || manifest.EnvHash != envHash(env) { - return incrementalPreloadValidation{reason: "env"} - } - normalizedPatterns := normalizePatternsForScope(wd, packageCacheScope(wd), patterns) - if len(manifest.Patterns) != len(normalizedPatterns) { - return incrementalPreloadValidation{reason: "patterns.length"} - } - for i, p := range normalizedPatterns { - if manifest.Patterns[i] != p { - return incrementalPreloadValidation{reason: "patterns.value"} - } - } - if len(manifest.ExtraFiles) > 0 { - extraStart := timeNow() - current, err := buildCacheFilesFromMeta(manifest.ExtraFiles) - logTiming(ctx, "incremental.preload_manifest.validate_extra_files", extraStart) - if err != nil || len(current) != len(manifest.ExtraFiles) { - return incrementalPreloadValidation{reason: "extra_files"} - } - for i := range current { - if current[i] != manifest.ExtraFiles[i] { - return incrementalPreloadValidation{reason: "extra_files.diff"} - } - } - } - localStart := timeNow() - packagesState := incrementalManifestCurrentLocalPackages(ctx, manifest.LocalPackages) - logTiming(ctx, "incremental.preload_manifest.validate_local_packages", localStart) - if !packagesState.valid { - return incrementalPreloadValidation{ - currentLocal: 
packagesState.currentLocal, - touched: packagesState.touched, - reason: "local_packages." + packagesState.reason, - } - } - if len(manifest.ExternalFiles) > 0 { - externalStart := timeNow() - current, err := buildCacheFilesFromMeta(manifest.ExternalFiles) - logTiming(ctx, "incremental.preload_manifest.validate_external_files", externalStart) - if err != nil || len(current) != len(manifest.ExternalFiles) { - return incrementalPreloadValidation{ - currentLocal: packagesState.currentLocal, - touched: packagesState.touched, - reason: "external_files", - } - } - for i := range current { - if current[i] != manifest.ExternalFiles[i] { - return incrementalPreloadValidation{ - currentLocal: packagesState.currentLocal, - touched: packagesState.touched, - reason: "external_files.diff", - } - } - } - } - if len(manifest.Outputs) == 0 { - return incrementalPreloadValidation{ - currentLocal: packagesState.currentLocal, - touched: packagesState.touched, - reason: "outputs", - } - } - return incrementalPreloadValidation{ - valid: true, - currentLocal: packagesState.currentLocal, - touched: packagesState.touched, - } -} - -type incrementalLocalPackagesState struct { - valid bool - currentLocal []packageFingerprint - touched []string - reason string -} - -func incrementalManifestCurrentLocalPackages(ctx context.Context, local []packageFingerprint) incrementalLocalPackagesState { - currentState := make([]packageFingerprint, 0, len(local)) - touched := make([]string, 0, len(local)) - var firstReason string - for _, fp := range local { - if len(fp.Files) == 0 { - if firstReason == "" { - firstReason = fp.PkgPath + ".files" - } - continue - } - storedFiles := filesFromMeta(fp.Files) - if len(storedFiles) == 0 { - if firstReason == "" { - firstReason = fp.PkgPath + ".stored_files" - } - continue - } - currentMeta, err := buildCacheFiles(storedFiles) - if err != nil { - debugf(ctx, "incremental.preload_manifest local_pkg=%s meta_error=%v", fp.PkgPath, err) - if firstReason == "" { - 
firstReason = fp.PkgPath + ".meta_error" - } - continue - } - currentFP := fp - currentFP.Files = append([]cacheFile(nil), currentMeta...) - sameMeta := len(currentMeta) == len(fp.Files) - if sameMeta { - for i := range currentMeta { - if currentMeta[i] != fp.Files[i] { - sameMeta = false - break - } - } - } - if !sameMeta { - if diffs := describeCacheFileDiffs(fp.Files, currentMeta); len(diffs) > 0 { - debugf(ctx, "incremental.preload_manifest local_pkg=%s meta_diff=%s", fp.PkgPath, strings.Join(diffs, "; ")) - } - contentHash, err := hashFiles(storedFiles) - if err != nil { - debugf(ctx, "incremental.preload_manifest local_pkg=%s content_error=%v", fp.PkgPath, err) - if firstReason == "" { - firstReason = fp.PkgPath + ".content_error" - } - continue - } - currentFP.ContentHash = contentHash - if contentHash != fp.ContentHash { - debugf(ctx, "incremental.preload_manifest local_pkg=%s stored_content=%s current_content=%s hash_files=%s", fp.PkgPath, fp.ContentHash, contentHash, strings.Join(storedFiles, ",")) - shapeHash, err := packageShapeHash(storedFiles) - if err != nil { - debugf(ctx, "incremental.preload_manifest local_pkg=%s shape_error=%v", fp.PkgPath, err) - if firstReason == "" { - firstReason = fp.PkgPath + ".shape_error" - } - continue - } - currentFP.ShapeHash = shapeHash - if shapeHash != fp.ShapeHash { - debugf(ctx, "incremental.preload_manifest local_pkg=%s stored_shape=%s current_shape=%s files=%s", fp.PkgPath, fp.ShapeHash, shapeHash, strings.Join(storedFiles, ",")) - if firstReason == "" { - firstReason = fp.PkgPath + ".shape_mismatch" - } - } else { - debugf(ctx, "incremental.preload_manifest local_pkg=%s content_changed_shape_unchanged", fp.PkgPath) - touched = append(touched, fp.PkgPath) - } - } - } - currentDirs, dirsChanged, err := packageDirectoryMetaChanged(fp, storedFiles) - if err != nil { - debugf(ctx, "incremental.preload_manifest local_pkg=%s dir_meta_error=%v", fp.PkgPath, err) - if firstReason == "" { - firstReason = fp.PkgPath + 
".dir_meta_error" - } - continue - } - currentFP.Dirs = currentDirs - if dirsChanged { - if changed, err := packageDirectoryIntroducedRelevantFiles(fp.Files); err != nil { - debugf(ctx, "incremental.preload_manifest local_pkg=%s dir_scan_error=%v", fp.PkgPath, err) - if firstReason == "" { - firstReason = fp.PkgPath + ".dir_scan_error" - } - continue - } else if changed { - debugf(ctx, "incremental.preload_manifest local_pkg=%s introduced_relevant_files=true", fp.PkgPath) - if firstReason == "" { - firstReason = fp.PkgPath + ".introduced_relevant_files" - } - } - } - currentState = append(currentState, currentFP) - } - if firstReason != "" { - return incrementalLocalPackagesState{ - currentLocal: currentState, - touched: touched, - reason: firstReason, - } - } - sort.Strings(touched) - return incrementalLocalPackagesState{ - valid: true, - currentLocal: currentState, - touched: touched, - } -} - -func validateIncrementalPreloadTouchedPackages(ctx context.Context, wd string, env []string, opts *GenerateOptions, local []packageFingerprint, touched []string) error { - if len(touched) == 0 { - return nil - } - cacheKey := touchedValidationKey(wd, env, opts, local, touched) - if cacheKey != "" { - cacheHitStart := timeNow() - if _, ok := readCache(cacheKey); ok { - logTiming(ctx, "incremental.preload_manifest.validate_touched_cache_hit", cacheHitStart) - return nil - } - } - cfg := &packages.Config{ - Context: ctx, - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedExportsFile | packages.NeedTypes | packages.NeedSyntax | packages.NeedTypesSizes, - Dir: wd, - Env: env, - BuildFlags: []string{"-tags=wireinject"}, - Fset: token.NewFileSet(), - } - if len(opts.Tags) > 0 { - cfg.BuildFlags[0] += " " + opts.Tags - } - loadStart := timeNow() - pkgs, err := packages.Load(cfg, touched...) 
- logTiming(ctx, "incremental.preload_manifest.validate_touched_load", loadStart) - if err != nil { - return err - } - errorsStart := timeNow() - byPath := make(map[string]*packages.Package, len(pkgs)) - for _, pkg := range pkgs { - if pkg != nil { - byPath[pkg.PkgPath] = pkg - } - } - for _, path := range touched { - if pkg := byPath[path]; pkg != nil && len(pkg.Errors) > 0 { - logTiming(ctx, "incremental.preload_manifest.validate_touched_errors", errorsStart) - return formatLocalTypeCheckError(wd, pkg.PkgPath, pkg.Errors) - } - } - logTiming(ctx, "incremental.preload_manifest.validate_touched_errors", errorsStart) - if cacheKey != "" { - cacheWriteStart := timeNow() - writeCache(cacheKey, []byte("ok")) - logTiming(ctx, "incremental.preload_manifest.validate_touched_cache_write", cacheWriteStart) - } - return nil -} - -func touchedValidationKey(wd string, env []string, opts *GenerateOptions, local []packageFingerprint, touched []string) string { - if len(touched) == 0 { - return "" - } - byPath := fingerprintsFromSlice(local) - h := sha256.New() - h.Write([]byte(touchedValidationVersion)) - h.Write([]byte{0}) - h.Write([]byte(packageCacheScope(wd))) - h.Write([]byte{0}) - h.Write([]byte(envHash(env))) - h.Write([]byte{0}) - if opts != nil { - h.Write([]byte(opts.Tags)) - } - h.Write([]byte{0}) - for _, pkgPath := range touched { - fp := byPath[pkgPath] - if fp == nil || fp.ContentHash == "" { - return "" - } - h.Write([]byte(pkgPath)) - h.Write([]byte{0}) - h.Write([]byte(fp.ContentHash)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} - -func incrementalManifestOutputs(manifest *incrementalManifest) ([]GenerateResult, bool) { - results := make([]GenerateResult, 0, len(manifest.Outputs)) - for _, out := range manifest.Outputs { - content, ok := readCache(out.ContentKey) - if !ok { - return nil, false - } - results = append(results, GenerateResult{ - PkgPath: out.PkgPath, - OutputPath: out.OutputPath, - Content: content, - }) - } - return 
results, true -} - -func readStateIncrementalManifest(selectorKey string, local []packageFingerprint) *incrementalManifest { - if len(local) == 0 { - return nil - } - stateKey := incrementalManifestStateKey(selectorKey, local) - manifest, ok := readIncrementalManifest(stateKey) - if !ok { - return nil - } - return manifest -} - -func incrementalManifestStateKey(selectorKey string, local []packageFingerprint) string { - h := sha256.New() - h.Write([]byte(selectorKey)) - h.Write([]byte{0}) - for _, fp := range snapshotPackageFingerprints(&incrementalFingerprintSnapshot{fingerprints: fingerprintsFromSlice(local)}) { - h.Write([]byte(fp.PkgPath)) - h.Write([]byte{0}) - h.Write([]byte(fp.ShapeHash)) - h.Write([]byte{0}) - } - return fmt.Sprintf("%x", h.Sum(nil)) -} - -func fingerprintsFromSlice(local []packageFingerprint) map[string]*packageFingerprint { - if len(local) == 0 { - return nil - } - out := make(map[string]*packageFingerprint, len(local)) - for i := range local { - fp := local[i] - out[fp.PkgPath] = &fp - } - return out -} - -func filesFromMeta(files []cacheFile) []string { - out := make([]string, 0, len(files)) - for _, f := range files { - out = append(out, filepath.Clean(f.Path)) - } - sort.Strings(out) - return out -} - -func describeCacheFileDiffs(stored []cacheFile, current []cacheFile) []string { - if len(stored) == 0 && len(current) == 0 { - return nil - } - storedByPath := make(map[string]cacheFile, len(stored)) - currentByPath := make(map[string]cacheFile, len(current)) - for _, file := range stored { - storedByPath[filepath.Clean(file.Path)] = file - } - for _, file := range current { - currentByPath[filepath.Clean(file.Path)] = file - } - paths := make([]string, 0, len(storedByPath)+len(currentByPath)) - seen := make(map[string]struct{}, len(storedByPath)+len(currentByPath)) - for path := range storedByPath { - if _, ok := seen[path]; ok { - continue - } - seen[path] = struct{}{} - paths = append(paths, path) - } - for path := range currentByPath 
{ - if _, ok := seen[path]; ok { - continue - } - seen[path] = struct{}{} - paths = append(paths, path) - } - sort.Strings(paths) - diffs := make([]string, 0, len(paths)) - for _, path := range paths { - storedFile, storedOK := storedByPath[path] - currentFile, currentOK := currentByPath[path] - switch { - case !storedOK: - diffs = append(diffs, fmt.Sprintf("%s added size=%d mtime=%d", path, currentFile.Size, currentFile.ModTime)) - case !currentOK: - diffs = append(diffs, fmt.Sprintf("%s removed size=%d mtime=%d", path, storedFile.Size, storedFile.ModTime)) - case storedFile != currentFile: - diffs = append(diffs, fmt.Sprintf("%s size:%d->%d mtime:%d->%d", path, storedFile.Size, currentFile.Size, storedFile.ModTime, currentFile.ModTime)) - } - } - return diffs -} - -func manifestNeedsLocalRefresh(stored []packageFingerprint, current []packageFingerprint) bool { - if len(stored) != len(current) { - return false - } - for i := range stored { - if stored[i].PkgPath != current[i].PkgPath { - return false - } - if stored[i].ContentHash == "" && current[i].ContentHash != "" { - return true - } - if len(stored[i].Dirs) == 0 && len(current[i].Dirs) > 0 { - return true - } - } - return false -} - -func packageDirectoryMetaChanged(fp packageFingerprint, storedFiles []string) ([]cacheFile, bool, error) { - dirs := packageFingerprintDirs(storedFiles) - if len(dirs) == 0 { - return nil, false, nil - } - current, err := buildCacheFiles(dirs) - if err != nil { - return nil, false, err - } - if len(fp.Dirs) != len(current) { - return current, true, nil - } - for i := range current { - if current[i] != fp.Dirs[i] { - return current, true, nil - } - } - return current, false, nil -} - -func packageDirectoryIntroducedRelevantFiles(files []cacheFile) (bool, error) { - dirs := make(map[string]struct{}) - old := make(map[string]struct{}, len(files)) - for _, f := range files { - path := filepath.Clean(f.Path) - dirs[filepath.Dir(path)] = struct{}{} - old[path] = struct{}{} - } - for 
dir := range dirs { - entries, err := os.ReadDir(dir) - if err != nil { - return false, err - } - for _, entry := range entries { - if entry.IsDir() { - continue - } - name := entry.Name() - if !strings.HasSuffix(name, ".go") { - continue - } - if strings.HasSuffix(name, "_test.go") { - continue - } - if strings.HasSuffix(name, "wire_gen.go") { - continue - } - path := filepath.Clean(filepath.Join(dir, name)) - if _, ok := old[path]; !ok { - return true, nil - } - } - } - return false, nil -} - -func incrementalManifestPath(key string) string { - return filepath.Join(cacheDir(), key+".iman") -} - -func readIncrementalManifest(key string) (*incrementalManifest, bool) { - data, err := osReadFile(incrementalManifestPath(key)) - if err != nil { - return nil, false - } - manifest, err := decodeIncrementalManifest(data) - if err != nil { - return nil, false - } - return manifest, true -} - -func writeIncrementalManifestFile(key string, manifest *incrementalManifest) { - data, err := encodeIncrementalManifest(manifest) - if err != nil { - return - } - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - tmp, err := osCreateTemp(dir, key+".iman-") - if err != nil { - return - } - _, writeErr := tmp.Write(data) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - if err := osRename(tmp.Name(), incrementalManifestPath(key)); err != nil { - osRemove(tmp.Name()) - } -} - -func removeIncrementalManifestFile(key string) { - if key == "" { - return - } - _ = osRemove(incrementalManifestPath(key)) -} - -func encodeIncrementalManifest(manifest *incrementalManifest) ([]byte, error) { - var buf bytes.Buffer - if manifest == nil { - return nil, fmt.Errorf("nil incremental manifest") - } - writeString := func(s string) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(s))); err != nil { - return err - } - _, err := buf.WriteString(s) - return err - } - writeCacheFiles := func(files 
[]cacheFile) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(files))); err != nil { - return err - } - for _, f := range files { - if err := writeString(f.Path); err != nil { - return err - } - if err := binary.Write(&buf, binary.LittleEndian, f.Size); err != nil { - return err - } - if err := binary.Write(&buf, binary.LittleEndian, f.ModTime); err != nil { - return err - } - } - return nil - } - writeExternalPkgs := func(pkgs []externalPackageExport) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(pkgs))); err != nil { - return err - } - for _, pkg := range pkgs { - if err := writeString(pkg.PkgPath); err != nil { - return err - } - if err := writeString(pkg.ExportFile); err != nil { - return err - } - } - return nil - } - writeFingerprints := func(fps []packageFingerprint) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(fps))); err != nil { - return err - } - for _, fp := range fps { - for _, s := range []string{fp.Version, fp.WD, fp.Tags, fp.PkgPath, fp.ShapeHash} { - if err := writeString(s); err != nil { - return err - } - } - if err := writeCacheFiles(fp.Files); err != nil { - return err - } - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(fp.LocalImports))); err != nil { - return err - } - for _, imp := range fp.LocalImports { - if err := writeString(imp); err != nil { - return err - } - } - } - return nil - } - writeOutputs := func(outputs []incrementalOutput) error { - if err := binary.Write(&buf, binary.LittleEndian, uint32(len(outputs))); err != nil { - return err - } - for _, out := range outputs { - for _, s := range []string{out.PkgPath, out.OutputPath, out.ContentKey} { - if err := writeString(s); err != nil { - return err - } - } - } - return nil - } - for _, s := range []string{manifest.Version, manifest.WD, manifest.Tags, manifest.Prefix, manifest.HeaderHash, manifest.EnvHash} { - if err := writeString(s); err != nil { - return nil, err - } - } - if err := 
binary.Write(&buf, binary.LittleEndian, uint32(len(manifest.Patterns))); err != nil { - return nil, err - } - for _, p := range manifest.Patterns { - if err := writeString(p); err != nil { - return nil, err - } - } - if err := writeFingerprints(manifest.LocalPackages); err != nil { - return nil, err - } - if err := writeExternalPkgs(manifest.ExternalPkgs); err != nil { - return nil, err - } - if err := writeCacheFiles(manifest.ExternalFiles); err != nil { - return nil, err - } - if err := writeCacheFiles(manifest.ExtraFiles); err != nil { - return nil, err - } - if err := writeOutputs(manifest.Outputs); err != nil { - return nil, err - } - return buf.Bytes(), nil -} - -func decodeIncrementalManifest(data []byte) (*incrementalManifest, error) { - r := bytes.NewReader(data) - readString := func() (string, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return "", err - } - buf := make([]byte, n) - if _, err := r.Read(buf); err != nil { - return "", err - } - return string(buf), nil - } - readCacheFiles := func() ([]cacheFile, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return nil, err - } - out := make([]cacheFile, 0, n) - for i := uint32(0); i < n; i++ { - path, err := readString() - if err != nil { - return nil, err - } - var size int64 - if err := binary.Read(r, binary.LittleEndian, &size); err != nil { - return nil, err - } - var modTime int64 - if err := binary.Read(r, binary.LittleEndian, &modTime); err != nil { - return nil, err - } - out = append(out, cacheFile{Path: path, Size: size, ModTime: modTime}) - } - return out, nil - } - readExternalPkgs := func() ([]externalPackageExport, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return nil, err - } - out := make([]externalPackageExport, 0, n) - for i := uint32(0); i < n; i++ { - pkgPath, err := readString() - if err != nil { - return nil, err - } - exportFile, err := 
readString() - if err != nil { - return nil, err - } - out = append(out, externalPackageExport{PkgPath: pkgPath, ExportFile: exportFile}) - } - return out, nil - } - readFingerprints := func() ([]packageFingerprint, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return nil, err - } - out := make([]packageFingerprint, 0, n) - for i := uint32(0); i < n; i++ { - version, err := readString() - if err != nil { - return nil, err - } - wd, err := readString() - if err != nil { - return nil, err - } - tags, err := readString() - if err != nil { - return nil, err - } - pkgPath, err := readString() - if err != nil { - return nil, err - } - shapeHash, err := readString() - if err != nil { - return nil, err - } - files, err := readCacheFiles() - if err != nil { - return nil, err - } - var importCount uint32 - if err := binary.Read(r, binary.LittleEndian, &importCount); err != nil { - return nil, err - } - localImports := make([]string, 0, importCount) - for j := uint32(0); j < importCount; j++ { - imp, err := readString() - if err != nil { - return nil, err - } - localImports = append(localImports, imp) - } - out = append(out, packageFingerprint{ - Version: version, - WD: wd, - Tags: tags, - PkgPath: pkgPath, - ShapeHash: shapeHash, - Files: files, - LocalImports: localImports, - }) - } - return out, nil - } - readOutputs := func() ([]incrementalOutput, error) { - var n uint32 - if err := binary.Read(r, binary.LittleEndian, &n); err != nil { - return nil, err - } - out := make([]incrementalOutput, 0, n) - for i := uint32(0); i < n; i++ { - pkgPath, err := readString() - if err != nil { - return nil, err - } - outputPath, err := readString() - if err != nil { - return nil, err - } - contentKey, err := readString() - if err != nil { - return nil, err - } - out = append(out, incrementalOutput{PkgPath: pkgPath, OutputPath: outputPath, ContentKey: contentKey}) - } - return out, nil - } - fields := make([]string, 6) - for i := range 
fields { - s, err := readString() - if err != nil { - return nil, err - } - fields[i] = s - } - var patternCount uint32 - if err := binary.Read(r, binary.LittleEndian, &patternCount); err != nil { - return nil, err - } - patterns := make([]string, 0, patternCount) - for i := uint32(0); i < patternCount; i++ { - p, err := readString() - if err != nil { - return nil, err - } - patterns = append(patterns, p) - } - localPackages, err := readFingerprints() - if err != nil { - return nil, err - } - externalPkgs, err := readExternalPkgs() - if err != nil { - return nil, err - } - externalFiles, err := readCacheFiles() - if err != nil { - return nil, err - } - extraFiles, err := readCacheFiles() - if err != nil { - return nil, err - } - outputs, err := readOutputs() - if err != nil { - return nil, err - } - return &incrementalManifest{ - Version: fields[0], - WD: fields[1], - Tags: fields[2], - Prefix: fields[3], - HeaderHash: fields[4], - EnvHash: fields[5], - Patterns: patterns, - LocalPackages: localPackages, - ExternalPkgs: externalPkgs, - ExternalFiles: externalFiles, - ExtraFiles: extraFiles, - Outputs: outputs, - }, nil -} - -func incrementalContentKey(content []byte) string { - sum := sha256.Sum256(content) - return fmt.Sprintf("%x", sum[:]) -} diff --git a/internal/wire/incremental_session.go b/internal/wire/incremental_session.go deleted file mode 100644 index 2fdaa2b..0000000 --- a/internal/wire/incremental_session.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "crypto/sha256" - "encoding/hex" - "go/ast" - "go/token" - "path/filepath" - "strings" - "sync" -) - -type incrementalSession struct { - fset *token.FileSet - mu sync.Mutex - parsedDeps map[string]cachedParsedFile -} - -type cachedParsedFile struct { - hash string - file *ast.File -} - -var incrementalSessions sync.Map - -func clearIncrementalSessions() { - incrementalSessions.Range(func(key, _ any) bool { - incrementalSessions.Delete(key) - return true - }) -} - -func sessionKey(wd string, env []string, tags string) string { - var b strings.Builder - b.WriteString(packageCacheScope(wd)) - b.WriteByte('\n') - b.WriteString(tags) - b.WriteByte('\n') - for _, entry := range env { - b.WriteString(entry) - b.WriteByte('\x00') - } - return b.String() -} - -func getIncrementalSession(wd string, env []string, tags string) *incrementalSession { - key := sessionKey(wd, env, tags) - if session, ok := incrementalSessions.Load(key); ok { - return session.(*incrementalSession) - } - session := &incrementalSession{ - fset: token.NewFileSet(), - parsedDeps: make(map[string]cachedParsedFile), - } - actual, _ := incrementalSessions.LoadOrStore(key, session) - return actual.(*incrementalSession) -} - -func (s *incrementalSession) getParsedDep(filename string, src []byte) (*ast.File, bool) { - if s == nil { - return nil, false - } - hash := hashSource(src) - s.mu.Lock() - defer s.mu.Unlock() - entry, ok := s.parsedDeps[filepath.Clean(filename)] - if !ok || entry.hash != hash { - return nil, false - } - return entry.file, true -} - -func (s *incrementalSession) storeParsedDep(filename string, src []byte, file *ast.File) { - if s == nil || file == nil { - return - } - s.mu.Lock() - defer s.mu.Unlock() - s.parsedDeps[filepath.Clean(filename)] = cachedParsedFile{ - hash: hashSource(src), - file: file, - } -} - -func hashSource(src []byte) string { - sum 
:= sha256.Sum256(src) - return hex.EncodeToString(sum[:]) -} diff --git a/internal/wire/incremental_summary.go b/internal/wire/incremental_summary.go deleted file mode 100644 index 934f637..0000000 --- a/internal/wire/incremental_summary.go +++ /dev/null @@ -1,656 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "bytes" - "crypto/sha256" - "encoding/binary" - "fmt" - "go/ast" - "go/types" - "path/filepath" - "sort" - - "golang.org/x/tools/go/packages" -) - -const incrementalSummaryVersion = "wire-incremental-summary-v1" - -type packageSummary struct { - Version string - WD string - Tags string - PkgPath string - ShapeHash string - LocalImports []string - ProviderSets []providerSetSummary - Injectors []injectorSummary -} - -type providerSetSummary struct { - VarName string - Providers []providerSummary - Imports []providerSetRefSummary - Bindings []ifaceBindingSummary - Values []string - Fields []fieldSummary - InputTypes []string -} - -type providerSummary struct { - PkgPath string - Name string - Args []providerInputSummary - Out []string - Varargs bool - IsStruct bool - HasCleanup bool - HasErr bool -} - -type providerInputSummary struct { - Type string - FieldName string -} - -type providerSetRefSummary struct { - PkgPath string - VarName string -} - -type ifaceBindingSummary struct { - Iface string - Provided string -} - -type fieldSummary struct { - PkgPath string - Parent string 
- Name string - Out []string -} - -type injectorSummary struct { - Name string - Inputs []string - Output string - Build providerSetSummary -} - -type packageSummarySnapshot struct { - Changed map[string]*packageSummary - Unchanged map[string]*packageSummary -} - -func incrementalSummaryKey(wd string, tags string, pkgPath string) string { - h := sha256.New() - h.Write([]byte(incrementalSummaryVersion)) - h.Write([]byte{0}) - h.Write([]byte(packageCacheScope(wd))) - h.Write([]byte{0}) - h.Write([]byte(tags)) - h.Write([]byte{0}) - h.Write([]byte(pkgPath)) - return fmt.Sprintf("%x", h.Sum(nil)) -} - -func incrementalSummaryPath(key string) string { - return filepath.Join(cacheDir(), key+".isum") -} - -func readIncrementalPackageSummary(key string) (*packageSummary, bool) { - data, err := osReadFile(incrementalSummaryPath(key)) - if err != nil { - return nil, false - } - summary, err := decodeIncrementalSummary(data) - if err != nil { - return nil, false - } - return summary, true -} - -func writeIncrementalPackageSummary(key string, summary *packageSummary) { - data, err := encodeIncrementalSummary(summary) - if err != nil { - return - } - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - tmp, err := osCreateTemp(dir, key+".isum-") - if err != nil { - return - } - _, writeErr := tmp.Write(data) - closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - if err := osRename(tmp.Name(), incrementalSummaryPath(key)); err != nil { - osRemove(tmp.Name()) - } -} - -func writeIncrementalPackageSummaries(loader *lazyLoader, pkgs []*packages.Package) { - writeIncrementalPackageSummariesWithSummary(loader, pkgs, nil, nil) -} - -func writeIncrementalPackageSummariesWithSummary(loader *lazyLoader, pkgs []*packages.Package, summary *summaryProviderResolver, only map[string]struct{}) { - if loader == nil || len(pkgs) == 0 { - return - } - moduleRoot := findModuleRoot(loader.wd) - all := 
collectAllPackages(pkgs) - for path, pkg := range loader.loaded { - if pkg != nil { - all[path] = pkg - } - } - allPkgs := make([]*packages.Package, 0, len(all)) - for _, pkg := range all { - allPkgs = append(allPkgs, pkg) - } - oc := newObjectCacheWithLoader(allPkgs, loader, nil, summary) - for _, pkg := range all { - if classifyPackageLocation(moduleRoot, pkg) != "local" { - continue - } - if len(only) > 0 { - if _, ok := only[pkg.PkgPath]; !ok { - continue - } - } - if pkg == nil || pkg.TypesInfo == nil || len(pkg.Syntax) == 0 { - continue - } - summary, err := buildPackageSummary(loader, oc, pkg) - if err != nil { - continue - } - writeIncrementalPackageSummary(incrementalSummaryKey(loader.wd, loader.tags, pkg.PkgPath), summary) - } -} - -func collectIncrementalPackageSummaries(loader *lazyLoader, pkgs []*packages.Package) *packageSummarySnapshot { - if loader == nil || loader.fingerprints == nil { - return nil - } - snapshot := &packageSummarySnapshot{ - Changed: make(map[string]*packageSummary), - Unchanged: make(map[string]*packageSummary), - } - changed := make(map[string]struct{}, len(loader.fingerprints.changed)) - for _, path := range loader.fingerprints.changed { - changed[path] = struct{}{} - } - moduleRoot := findModuleRoot(loader.wd) - oc := newObjectCache(pkgs, loader) - for _, pkg := range collectAllPackages(pkgs) { - if pkg == nil { - continue - } - if classifyPackageLocation(moduleRoot, pkg) != "local" { - continue - } - if _, ok := changed[pkg.PkgPath]; ok { - if pkg.TypesInfo == nil || len(pkg.Syntax) == 0 { - loaded, errs := oc.ensurePackage(pkg.PkgPath) - if len(errs) > 0 { - continue - } - pkg = loaded - } - if pkg == nil || pkg.TypesInfo == nil || len(pkg.Syntax) == 0 { - continue - } - summary, err := buildPackageSummary(loader, oc, pkg) - if err != nil { - continue - } - snapshot.Changed[pkg.PkgPath] = summary - continue - } - if summary, ok := readIncrementalPackageSummary(incrementalSummaryKey(loader.wd, loader.tags, pkg.PkgPath)); ok { 
- snapshot.Unchanged[pkg.PkgPath] = summary - } - } - return snapshot -} - -func buildPackageSummary(loader *lazyLoader, oc *objectCache, pkg *packages.Package) (*packageSummary, error) { - if loader == nil || oc == nil || pkg == nil { - return nil, fmt.Errorf("missing loader, object cache, or package") - } - summary := &packageSummary{ - Version: incrementalSummaryVersion, - WD: filepath.Clean(loader.wd), - Tags: loader.tags, - PkgPath: pkg.PkgPath, - } - if snapshot := loader.fingerprints; snapshot != nil { - if fp := snapshot.fingerprints[pkg.PkgPath]; fp != nil { - summary.ShapeHash = fp.ShapeHash - summary.LocalImports = append(summary.LocalImports, fp.LocalImports...) - } - } - scope := pkg.Types.Scope() - for _, name := range scope.Names() { - obj := scope.Lookup(name) - if !isProviderSetType(obj.Type()) { - continue - } - item, errs := oc.get(obj) - if len(errs) > 0 { - continue - } - pset, ok := item.(*ProviderSet) - if !ok { - continue - } - summary.ProviderSets = append(summary.ProviderSets, summarizeProviderSet(pset)) - } - sort.Slice(summary.ProviderSets, func(i, j int) bool { - return summary.ProviderSets[i].VarName < summary.ProviderSets[j].VarName - }) - for _, file := range pkg.Syntax { - for _, decl := range file.Decls { - fn, ok := decl.(*ast.FuncDecl) - if !ok { - continue - } - buildCall, err := findInjectorBuild(pkg.TypesInfo, fn) - if err != nil || buildCall == nil { - continue - } - sig := pkg.TypesInfo.ObjectOf(fn.Name).Type().(*types.Signature) - ins, out, err := injectorFuncSignature(sig) - if err != nil { - continue - } - injectorArgs := &InjectorArgs{ - Name: fn.Name.Name, - Tuple: ins, - Pos: fn.Pos(), - } - set, errs := oc.processNewSet(pkg.TypesInfo, pkg.PkgPath, buildCall, injectorArgs, "") - if len(errs) > 0 { - continue - } - summary.Injectors = append(summary.Injectors, injectorSummary{ - Name: fn.Name.Name, - Inputs: summarizeTuple(ins), - Output: summaryTypeString(out.out), - Build: summarizeProviderSet(set), - }) - } - } - 
sort.Slice(summary.Injectors, func(i, j int) bool { - return summary.Injectors[i].Name < summary.Injectors[j].Name - }) - return summary, nil -} - -func summarizeProviderSet(pset *ProviderSet) providerSetSummary { - if pset == nil { - return providerSetSummary{} - } - summary := providerSetSummary{ - VarName: pset.VarName, - } - for _, provider := range pset.Providers { - summary.Providers = append(summary.Providers, summarizeProvider(provider)) - } - for _, imported := range pset.Imports { - summary.Imports = append(summary.Imports, providerSetRefSummary{ - PkgPath: imported.PkgPath, - VarName: imported.VarName, - }) - } - for _, binding := range pset.Bindings { - summary.Bindings = append(summary.Bindings, ifaceBindingSummary{ - Iface: summaryTypeString(binding.Iface), - Provided: summaryTypeString(binding.Provided), - }) - } - for _, value := range pset.Values { - summary.Values = append(summary.Values, summaryTypeString(value.Out)) - } - for _, field := range pset.Fields { - item := fieldSummary{ - Parent: summaryTypeString(field.Parent), - Name: field.Name, - Out: summarizeTypes(field.Out), - } - if field.Pkg != nil { - item.PkgPath = field.Pkg.Path() - } - summary.Fields = append(summary.Fields, item) - } - if pset.InjectorArgs != nil { - summary.InputTypes = summarizeTuple(pset.InjectorArgs.Tuple) - } - sort.Slice(summary.Providers, func(i, j int) bool { - return summary.Providers[i].PkgPath+"."+summary.Providers[i].Name < summary.Providers[j].PkgPath+"."+summary.Providers[j].Name - }) - sort.Slice(summary.Imports, func(i, j int) bool { - return summary.Imports[i].PkgPath+"."+summary.Imports[i].VarName < summary.Imports[j].PkgPath+"."+summary.Imports[j].VarName - }) - sort.Slice(summary.Bindings, func(i, j int) bool { - return summary.Bindings[i].Iface+":"+summary.Bindings[i].Provided < summary.Bindings[j].Iface+":"+summary.Bindings[j].Provided - }) - sort.Strings(summary.Values) - sort.Slice(summary.Fields, func(i, j int) bool { - return 
summary.Fields[i].Parent+"."+summary.Fields[i].Name < summary.Fields[j].Parent+"."+summary.Fields[j].Name - }) - sort.Strings(summary.InputTypes) - return summary -} - -func summarizeProvider(provider *Provider) providerSummary { - summary := providerSummary{ - Name: provider.Name, - Varargs: provider.Varargs, - IsStruct: provider.IsStruct, - HasCleanup: provider.HasCleanup, - HasErr: provider.HasErr, - Out: summarizeTypes(provider.Out), - } - if provider.Pkg != nil { - summary.PkgPath = provider.Pkg.Path() - } - for _, arg := range provider.Args { - summary.Args = append(summary.Args, providerInputSummary{ - Type: summaryTypeString(arg.Type), - FieldName: arg.FieldName, - }) - } - return summary -} - -func summarizeTuple(tuple *types.Tuple) []string { - if tuple == nil { - return nil - } - out := make([]string, 0, tuple.Len()) - for i := 0; i < tuple.Len(); i++ { - out = append(out, summaryTypeString(tuple.At(i).Type())) - } - return out -} - -func summarizeTypes(typesList []types.Type) []string { - out := make([]string, 0, len(typesList)) - for _, t := range typesList { - out = append(out, summaryTypeString(t)) - } - return out -} - -func summaryTypeString(t types.Type) string { - if t == nil { - return "" - } - return types.TypeString(t, func(pkg *types.Package) string { - if pkg == nil { - return "" - } - return pkg.Path() - }) -} - -func encodeIncrementalSummary(summary *packageSummary) ([]byte, error) { - if summary == nil { - return nil, fmt.Errorf("nil package summary") - } - var buf bytes.Buffer - enc := binarySummaryEncoder{buf: &buf} - enc.string(summary.Version) - enc.string(summary.WD) - enc.string(summary.Tags) - enc.string(summary.PkgPath) - enc.string(summary.ShapeHash) - enc.strings(summary.LocalImports) - enc.providerSets(summary.ProviderSets) - enc.u32(uint32(len(summary.Injectors))) - for _, injector := range summary.Injectors { - enc.string(injector.Name) - enc.strings(injector.Inputs) - enc.string(injector.Output) - 
enc.providerSet(injector.Build) - } - if enc.err != nil { - return nil, enc.err - } - return buf.Bytes(), nil -} - -func decodeIncrementalSummary(data []byte) (*packageSummary, error) { - dec := binarySummaryDecoder{r: bytes.NewReader(data)} - summary := &packageSummary{ - Version: dec.string(), - WD: dec.string(), - Tags: dec.string(), - PkgPath: dec.string(), - ShapeHash: dec.string(), - } - summary.LocalImports = dec.strings() - summary.ProviderSets = dec.providerSets() - for n := dec.u32(); n > 0; n-- { - summary.Injectors = append(summary.Injectors, injectorSummary{ - Name: dec.string(), - Inputs: dec.strings(), - Output: dec.string(), - Build: dec.providerSet(), - }) - } - if dec.err != nil { - return nil, dec.err - } - return summary, nil -} - -type binarySummaryEncoder struct { - buf *bytes.Buffer - err error -} - -func (e *binarySummaryEncoder) u32(v uint32) { - if e.err != nil { - return - } - e.err = binary.Write(e.buf, binary.LittleEndian, v) -} - -func (e *binarySummaryEncoder) string(s string) { - e.u32(uint32(len(s))) - if e.err != nil { - return - } - _, e.err = e.buf.WriteString(s) -} - -func (e *binarySummaryEncoder) bool(v bool) { - if e.err != nil { - return - } - var b byte - if v { - b = 1 - } - e.err = e.buf.WriteByte(b) -} - -func (e *binarySummaryEncoder) strings(values []string) { - e.u32(uint32(len(values))) - for _, v := range values { - e.string(v) - } -} - -func (e *binarySummaryEncoder) providerSets(values []providerSetSummary) { - e.u32(uint32(len(values))) - for _, value := range values { - e.providerSet(value) - } -} - -func (e *binarySummaryEncoder) providerSet(value providerSetSummary) { - e.string(value.VarName) - e.u32(uint32(len(value.Providers))) - for _, provider := range value.Providers { - e.string(provider.PkgPath) - e.string(provider.Name) - e.u32(uint32(len(provider.Args))) - for _, arg := range provider.Args { - e.string(arg.Type) - e.string(arg.FieldName) - } - e.strings(provider.Out) - e.bool(provider.Varargs) - 
e.bool(provider.IsStruct) - e.bool(provider.HasCleanup) - e.bool(provider.HasErr) - } - e.u32(uint32(len(value.Imports))) - for _, imported := range value.Imports { - e.string(imported.PkgPath) - e.string(imported.VarName) - } - e.u32(uint32(len(value.Bindings))) - for _, binding := range value.Bindings { - e.string(binding.Iface) - e.string(binding.Provided) - } - e.strings(value.Values) - e.u32(uint32(len(value.Fields))) - for _, field := range value.Fields { - e.string(field.PkgPath) - e.string(field.Parent) - e.string(field.Name) - e.strings(field.Out) - } - e.strings(value.InputTypes) -} - -type binarySummaryDecoder struct { - r *bytes.Reader - err error -} - -func (d *binarySummaryDecoder) u32() uint32 { - if d.err != nil { - return 0 - } - var v uint32 - d.err = binary.Read(d.r, binary.LittleEndian, &v) - return v -} - -func (d *binarySummaryDecoder) string() string { - n := d.u32() - if d.err != nil { - return "" - } - buf := make([]byte, n) - _, d.err = d.r.Read(buf) - return string(buf) -} - -func (d *binarySummaryDecoder) bool() bool { - if d.err != nil { - return false - } - b, err := d.r.ReadByte() - if err != nil { - d.err = err - return false - } - return b != 0 -} - -func (d *binarySummaryDecoder) strings() []string { - n := d.u32() - if d.err != nil { - return nil - } - out := make([]string, 0, n) - for i := uint32(0); i < n; i++ { - out = append(out, d.string()) - } - return out -} - -func (d *binarySummaryDecoder) providerSets() []providerSetSummary { - n := d.u32() - if d.err != nil { - return nil - } - out := make([]providerSetSummary, 0, n) - for i := uint32(0); i < n; i++ { - out = append(out, d.providerSet()) - } - return out -} - -func (d *binarySummaryDecoder) providerSet() providerSetSummary { - value := providerSetSummary{ - VarName: d.string(), - } - for n := d.u32(); n > 0; n-- { - provider := providerSummary{ - PkgPath: d.string(), - Name: d.string(), - } - for m := d.u32(); m > 0; m-- { - provider.Args = append(provider.Args, 
providerInputSummary{ - Type: d.string(), - FieldName: d.string(), - }) - } - provider.Out = d.strings() - provider.Varargs = d.bool() - provider.IsStruct = d.bool() - provider.HasCleanup = d.bool() - provider.HasErr = d.bool() - value.Providers = append(value.Providers, provider) - } - for n := d.u32(); n > 0; n-- { - value.Imports = append(value.Imports, providerSetRefSummary{ - PkgPath: d.string(), - VarName: d.string(), - }) - } - for n := d.u32(); n > 0; n-- { - value.Bindings = append(value.Bindings, ifaceBindingSummary{ - Iface: d.string(), - Provided: d.string(), - }) - } - value.Values = d.strings() - for n := d.u32(); n > 0; n-- { - value.Fields = append(value.Fields, fieldSummary{ - PkgPath: d.string(), - Parent: d.string(), - Name: d.string(), - Out: d.strings(), - }) - } - value.InputTypes = d.strings() - return value -} diff --git a/internal/wire/incremental_summary_test.go b/internal/wire/incremental_summary_test.go deleted file mode 100644 index ae85651..0000000 --- a/internal/wire/incremental_summary_test.go +++ /dev/null @@ -1,295 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package wire - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" -) - -func TestIncrementalSummaryEncodeDecodeRoundTrip(t *testing.T) { - summary := &packageSummary{ - Version: incrementalSummaryVersion, - WD: "/tmp/app", - Tags: "dev", - PkgPath: "example.com/app/dep", - ShapeHash: "abc123", - LocalImports: []string{"example.com/app/shared"}, - ProviderSets: []providerSetSummary{{ - VarName: "Set", - Providers: []providerSummary{{ - PkgPath: "example.com/app/dep", - Name: "NewThing", - Args: []providerInputSummary{{Type: "string"}}, - Out: []string{"*example.com/app/dep.Thing"}, - HasCleanup: true, - }}, - Imports: []providerSetRefSummary{{PkgPath: "example.com/app/shared", VarName: "SharedSet"}}, - Bindings: []ifaceBindingSummary{{Iface: "error", Provided: "*example.com/app/dep.Thing"}}, - Values: []string{"string"}, - Fields: []fieldSummary{{PkgPath: "example.com/app/dep", Parent: "example.com/app/dep.Config", Name: "Name", Out: []string{"string"}}}, - InputTypes: []string{"context.Context"}, - }}, - Injectors: []injectorSummary{{ - Name: "Init", - Inputs: []string{"context.Context"}, - Output: "*example.com/app/dep.Thing", - Build: providerSetSummary{ - Imports: []providerSetRefSummary{{PkgPath: "example.com/app/shared", VarName: "SharedSet"}}, - }, - }}, - } - data, err := encodeIncrementalSummary(summary) - if err != nil { - t.Fatalf("encodeIncrementalSummary: %v", err) - } - got, err := decodeIncrementalSummary(data) - if err != nil { - t.Fatalf("decodeIncrementalSummary: %v", err) - } - if got.Version != summary.Version || got.PkgPath != summary.PkgPath || got.ShapeHash != summary.ShapeHash { - t.Fatalf("decoded summary mismatch: %+v", got) - } - if len(got.ProviderSets) != 1 || got.ProviderSets[0].VarName != "Set" { - t.Fatalf("decoded provider sets mismatch: %+v", got.ProviderSets) - } - if len(got.Injectors) != 1 || got.Injectors[0].Name != "Init" { - t.Fatalf("decoded injectors mismatch: %+v", got.Injectors) - } -} - -func 
TestBuildPackageSummary(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.Set)", - "\treturn nil", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Foo struct{ Message string }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func New(msg string) *Foo { return &Foo{Message: msg} }", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - pkgs, loader, errs := load(ctx, root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("load returned errors: %v", errs) - } - oc := newObjectCache(pkgs, loader) - loadedDep, errs := oc.ensurePackage("example.com/app/dep") - if len(errs) > 0 { - t.Fatalf("ensurePackage returned errors: %v", errs) - } - summary, err := buildPackageSummary(loader, oc, loadedDep) - if err != nil { - t.Fatalf("buildPackageSummary: %v", err) - } - if summary.PkgPath != "example.com/app/dep" { - t.Fatalf("summary pkg path = %q", summary.PkgPath) - } - if 
len(summary.ProviderSets) != 1 || summary.ProviderSets[0].VarName != "Set" { - t.Fatalf("unexpected provider sets: %+v", summary.ProviderSets) - } - if len(summary.ProviderSets[0].Providers) != 2 { - t.Fatalf("unexpected providers: %+v", summary.ProviderSets[0].Providers) - } - loadedApp, errs := oc.ensurePackage("example.com/app/app") - if len(errs) > 0 { - t.Fatalf("ensurePackage app returned errors: %v", errs) - } - appSummary, err := buildPackageSummary(loader, oc, loadedApp) - if err != nil { - t.Fatalf("buildPackageSummary app: %v", err) - } - if len(appSummary.Injectors) != 1 || appSummary.Injectors[0].Name != "Init" { - t.Fatalf("unexpected injectors: %+v", appSummary.Injectors) - } - if len(appSummary.Injectors[0].Build.Imports) != 1 || appSummary.Injectors[0].Build.Imports[0].PkgPath != "example.com/app/dep" { - t.Fatalf("unexpected injector imports: %+v", appSummary.Injectors[0].Build.Imports) - } -} - -func TestCollectIncrementalPackageSummariesUsesCacheForUnchanged(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.Set)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct{ Message string }", - "", - "func 
NewMessage() string { return \"ok\" }", - "", - "func New(msg string) *Foo { return &Foo{Message: msg} }", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - t.Fatalf("unexpected Generate result: %+v", gens) - } - pkgs, loader, errs := load(ctx, root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("load returned errors while seeding summaries: %v", errs) - } - if _, errs := newObjectCache(pkgs, loader).ensurePackage("example.com/app/app"); len(errs) > 0 { - t.Fatalf("ensurePackage returned errors while seeding summaries: %v", errs) - } - writeIncrementalPackageSummaries(loader, pkgs) - - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct{ Message string; Count int }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func NewCount() int { return 7 }", - "", - "func New(msg string, count int) *Foo { return &Foo{Message: msg, Count: count} }", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - - pkgs, loader, errs = load(ctx, root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("load returned errors: %v", errs) - } - snapshot := collectIncrementalPackageSummaries(loader, pkgs) - if snapshot == nil { - t.Fatal("collectIncrementalPackageSummaries returned nil") - } - if _, ok := snapshot.Changed["example.com/app/dep"]; !ok { - 
t.Fatalf("expected changed dep summary, got %+v", snapshot.Changed) - } - if _, ok := snapshot.Unchanged["example.com/app/app"]; !ok { - t.Fatalf("expected unchanged app summary from cache, got %+v", snapshot.Unchanged) - } - if len(snapshot.Unchanged["example.com/app/app"].Injectors) != 1 { - t.Fatalf("unexpected cached app summary: %+v", snapshot.Unchanged["example.com/app/app"]) - } - if len(snapshot.Changed["example.com/app/dep"].ProviderSets) != 1 { - t.Fatalf("unexpected changed dep summary: %+v", snapshot.Changed["example.com/app/dep"]) - } -} diff --git a/internal/wire/incremental_test.go b/internal/wire/incremental_test.go deleted file mode 100644 index a531123..0000000 --- a/internal/wire/incremental_test.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package wire - -import ( - "context" - "testing" -) - -func TestIncrementalEnabledDefaultOff(t *testing.T) { - if IncrementalEnabled(context.Background(), nil) { - t.Fatal("IncrementalEnabled should default to false") - } -} - -func TestIncrementalEnabledFromEnv(t *testing.T) { - env := []string{ - "FOO=bar", - IncrementalEnvVar + "=true", - } - if !IncrementalEnabled(context.Background(), env) { - t.Fatal("IncrementalEnabled should read the environment variable") - } -} - -func TestIncrementalEnabledUsesLastEnvValue(t *testing.T) { - env := []string{ - IncrementalEnvVar + "=false", - IncrementalEnvVar + "=true", - } - if !IncrementalEnabled(context.Background(), env) { - t.Fatal("IncrementalEnabled should use the last matching env value") - } -} - -func TestIncrementalEnabledContextOverridesEnv(t *testing.T) { - env := []string{ - IncrementalEnvVar + "=false", - } - ctx := WithIncremental(context.Background(), true) - if !IncrementalEnabled(ctx, env) { - t.Fatal("context override should take precedence over env") - } -} - -func TestIncrementalEnabledInvalidEnvFallsBackFalse(t *testing.T) { - env := []string{ - IncrementalEnvVar + "=maybe", - } - if IncrementalEnabled(context.Background(), env) { - t.Fatal("invalid env value should not enable incremental mode") - } -} diff --git a/internal/wire/load_debug.go b/internal/wire/load_debug.go index fd8c4d7..d3d5fc1 100644 --- a/internal/wire/load_debug.go +++ b/internal/wire/load_debug.go @@ -124,7 +124,7 @@ func logLoadDebug(ctx context.Context, scope string, mode packages.LoadMode, sub } if parseStats != nil { snap := parseStats.snapshot() - debugf(ctx, "load.debug scope=%s parse.calls=%d parse.primary=%d parse.deps=%d parse.cache_hits=%d parse.cache_misses=%d parse.errors=%d parse.total=%s", + debugf(ctx, "load.debug scope=%s parse.calls=%d parse.primary=%d parse.deps=%d parse.cache_hits=%d parse.cache_misses=%d parse.errors=%d parse.cumulative=%s", scope, snap.calls, snap.primaryCalls, @@ -190,6 +190,23 @@ func 
summarizeLoadScope(wd string, pkgs []*packages.Package) loadScopeStats { return stats } +func collectAllPackages(pkgs []*packages.Package) map[string]*packages.Package { + all := make(map[string]*packages.Package) + stack := append([]*packages.Package(nil), pkgs...) + for len(stack) > 0 { + p := stack[len(stack)-1] + stack = stack[:len(stack)-1] + if p == nil || all[p.PkgPath] != nil { + continue + } + all[p.PkgPath] = p + for _, imp := range p.Imports { + stack = append(stack, imp) + } + } + return all +} + func classifyPackageLocation(moduleRoot string, pkg *packages.Package) string { if moduleRoot == "" || pkg == nil { return "unknown" @@ -210,8 +227,8 @@ func classifyPackageLocation(moduleRoot string, pkg *packages.Package) string { } func isWithinRoot(root, name string) bool { - cleanRoot := filepath.Clean(root) - cleanName := filepath.Clean(name) + cleanRoot := canonicalPath(root) + cleanName := canonicalPath(name) if cleanName == cleanRoot { return true } @@ -222,6 +239,14 @@ func isWithinRoot(root, name string) bool { return rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) } +func canonicalPath(path string) string { + clean := filepath.Clean(path) + if resolved, err := filepath.EvalSymlinks(clean); err == nil && resolved != "" { + return filepath.Clean(resolved) + } + return clean +} + func topPackageMetrics(metrics []packageMetric) []string { sort.Slice(metrics, func(i, j int) bool { if metrics[i].count == metrics[j].count { diff --git a/internal/wire/loader_test.go b/internal/wire/loader_test.go deleted file mode 100644 index 37e27d9..0000000 --- a/internal/wire/loader_test.go +++ /dev/null @@ -1,2596 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "os" - "path/filepath" - "strings" - "testing" - "time" -) - -func TestLoadAndGenerateModule(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "app.go"), strings.Join([]string{ - "package app", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.New)", - "\treturn nil", - "}", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Foo struct{}", - "", - "func New() *Foo {", - "\treturn &Foo{}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "noop", "noop.go"), strings.Join([]string{ - "package noop", - "", - "type Thing struct{}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - - info, errs := Load(ctx, root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("Load returned errors: %v", errs) - } - if info == nil { - t.Fatal("Load returned nil info") - } - if len(info.Injectors) != 1 || info.Injectors[0].FuncName 
!= "Init" { - t.Fatalf("Load returned unexpected injectors: %+v", info.Injectors) - } - - gens, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("Generate returned errors: %v", errs) - } - if len(gens) != 1 { - t.Fatalf("Generate returned %d results, want 1", len(gens)) - } - if len(gens[0].Errs) > 0 { - t.Fatalf("Generate result had errors: %v", gens[0].Errs) - } - if len(gens[0].Content) == 0 { - t.Fatal("Generate returned empty output for wire package") - } - if gens[0].OutputPath == "" { - t.Fatal("Generate returned empty output path") - } - - noops, errs := Generate(ctx, root, env, []string{"./noop"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("Generate noop returned errors: %v", errs) - } - if len(noops) != 1 { - t.Fatalf("Generate noop returned %d results, want 1", len(noops)) - } - if len(noops[0].Errs) > 0 { - t.Fatalf("Generate noop result had errors: %v", noops[0].Errs) - } - if noops[0].OutputPath == "" { - t.Fatal("Generate noop returned empty output path") - } - if len(noops[0].Content) != 0 { - t.Fatal("Generate noop returned unexpected output") - } -} - -func TestLoadAndGenerateModuleIncrementalMatches(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", 
- "", - "func NewMessage() string { return \"ok\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - - info, errs := Load(context.Background(), root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("Load returned errors: %v", errs) - } - if info == nil || len(info.Injectors) != 1 { - t.Fatalf("Load returned unexpected info: %+v errs=%v", info, errs) - } - - incrementalCtx := WithIncremental(context.Background(), true) - incrementalInfo, errs := Load(incrementalCtx, root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("incremental Load returned errors: %v", errs) - } - if incrementalInfo == nil || len(incrementalInfo.Injectors) != 1 { - t.Fatalf("incremental Load returned unexpected info: %+v errs=%v", incrementalInfo, errs) - } - - normalGens, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("Generate returned errors: %v", errs) - } - incrementalGens, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("incremental Generate returned errors: %v", errs) - } - if len(normalGens) != 1 || len(incrementalGens) != 1 { - t.Fatalf("unexpected result counts: normal=%d incremental=%d", len(normalGens), len(incrementalGens)) - } - if len(normalGens[0].Errs) > 0 || len(incrementalGens[0].Errs) > 0 { - t.Fatalf("unexpected generate errors: normal=%v incremental=%v", normalGens[0].Errs, incrementalGens[0].Errs) - } - if normalGens[0].OutputPath != incrementalGens[0].OutputPath { - t.Fatalf("output paths differ: normal=%q incremental=%q", normalGens[0].OutputPath, incrementalGens[0].OutputPath) - } - if 
string(normalGens[0].Content) != string(incrementalGens[0].Content) { - t.Fatalf("generated content differs between normal and incremental modes") - } -} - -func TestGenerateIncrementalBodyOnlyChangeUsesPreloadManifestAndReusesOutput(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "app", "wire_gen.go"), strings.Join([]string{ - "//go:build !wireinject", - "", - "package app", - "", - "func generated() {}", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "app", "app_test.go"), strings.Join([]string{ - "package app", - "", - "func testOnly() {}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := 
WithIncremental(context.Background(), true) - - var firstLabels []string - firstCtx := WithTiming(ctx, func(label string, _ time.Duration) { - firstLabels = append(firstLabels, label) - }) - first, errs := Generate(firstCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - if !containsLabel(firstLabels, "load.packages.lazy.load") { - t.Fatalf("expected first incremental generate to perform lazy load, labels=%v", firstLabels) - } - - if err := os.WriteFile(depFile, []byte(strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"b\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } - - var secondLabels []string - secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { - secondLabels = append(secondLabels, label) - }) - second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("second Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected second Generate result: %+v", second) - } - if containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected second Generate to reuse preload manifest after body-only change, labels=%v", secondLabels) - } - if string(first[0].Content) != string(second[0].Content) { - t.Fatal("expected body-only change to reuse identical generated output") - } -} - -func TestGenerateIncrementalTouchedValidationCacheReusesSuccessfulValidation(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - 
repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - writeBodyVariant := func(message string) { - t.Helper() - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"" + message + "\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - } - writeBodyVariant("a") - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - - writeBodyVariant("b") - - var secondLabels []string - secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { - secondLabels = append(secondLabels, label) - }) - second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("second Generate returned errors: %v", errs) - } - if len(second) 
!= 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected second Generate result: %+v", second) - } - if containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected first body-only variant change to avoid generate.load, labels=%v", secondLabels) - } - if containsLabel(secondLabels, "incremental.preload_manifest.validate_touched_cache_hit") { - t.Fatalf("did not expect first body-only variant change to hit touched validation cache, labels=%v", secondLabels) - } - - writeBodyVariant("a") - third, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("third Generate returned errors: %v", errs) - } - if len(third) != 1 || len(third[0].Errs) > 0 { - t.Fatalf("unexpected third Generate result: %+v", third) - } - - writeBodyVariant("b") - - var fourthLabels []string - fourthCtx := WithTiming(ctx, func(label string, _ time.Duration) { - fourthLabels = append(fourthLabels, label) - }) - fourth, errs := Generate(fourthCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("fourth Generate returned errors: %v", errs) - } - if len(fourth) != 1 || len(fourth[0].Errs) > 0 { - t.Fatalf("unexpected fourth Generate result: %+v", fourth) - } - if containsLabel(fourthLabels, "generate.load") { - t.Fatalf("expected repeated body-only variant change to avoid generate.load, labels=%v", fourthLabels) - } - if !containsLabel(fourthLabels, "incremental.preload_manifest.validate_touched_cache_hit") { - t.Fatalf("expected repeated body-only variant change to hit touched validation cache, labels=%v", fourthLabels) - } - if string(first[0].Content) != string(fourth[0].Content) { - t.Fatal("expected repeated body-only variant change to reuse identical generated output") - } -} - -func TestGenerateIncrementalConstValueChangeUsesPreloadManifestAndReusesOutput(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - 
osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"blue\"", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - var secondLabels []string - secondCtx := WithTiming(ctx, func(label 
string, _ time.Duration) { - secondLabels = append(secondLabels, label) - }) - second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("second Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected second Generate result: %+v", second) - } - if containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected const-value change to reuse preload manifest, labels=%v", secondLabels) - } - if string(first[0].Content) != string(second[0].Content) { - t.Fatal("expected const-value change to reuse identical generated output") - } -} - -func TestGenerateIncrementalBodyOnlyInvalidChangeDoesNotReusePreloadManifest(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import 
\"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn missing", - "}", - "", - }, "\n")) - - var secondLabels []string - secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { - secondLabels = append(secondLabels, label) - }) - second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(second) != 0 { - t.Fatalf("expected invalid body-only change to stop before generation, got %+v", second) - } - if len(errs) == 0 { - t.Fatal("expected invalid body-only change to return errors") - } - if !containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected invalid body-only change to bypass preload manifest and load packages, labels=%v", secondLabels) - } - if got := errs[0].Error(); !strings.Contains(got, "undefined: missing") { - t.Fatalf("expected load/type-check error from invalid body-only change, got %q", got) - } -} - -func TestGenerateIncrementalScenarioMatrix(t *testing.T) { - t.Parallel() - - type scenarioExpectation struct { - mode string - wantErr bool - wantSameOutput bool - } - - scenarios := []struct { - name string - apply func(t *testing.T, fx incrementalScenarioFixture) - want scenarioExpectation - }{ - { - name: "comment_only_change_reuses_preload", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - 
"", - "// SQLText controls SQL highlighting in log output.", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "preload", wantSameOutput: true}, - }, - { - name: "whitespace_only_change_reuses_preload", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "", - "func New(msg string) *Foo {", - "", - "\treturn &Foo{Message: helper(msg)}", - "", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "preload", wantSameOutput: true}, - }, - { - name: "function_body_change_reuses_preload", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string {", - "\treturn helper(SQLText)", - "}", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "preload", wantSameOutput: true}, - }, - { - name: "method_body_change_uses_local_fastpath", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func (f 
Foo) Summary() string {", - "\treturn helper(f.Message)", - "}", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: true}, - }, - { - name: "const_value_change_reuses_preload", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"blue\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "preload", wantSameOutput: true}, - }, - { - name: "var_initializer_change_reuses_preload", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 2", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "preload", wantSameOutput: true}, - }, - { - name: "add_top_level_helper_uses_local_fastpath", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func NewTag() string { return \"tag\" }", - 
"", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: true}, - }, - { - name: "import_only_implementation_change_reuses_preload", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "import \"fmt\"", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return fmt.Sprint(msg) }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "preload", wantSameOutput: true}, - }, - { - name: "signature_change_uses_local_fastpath", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 7", - "", - "type Foo struct {", - "\tMessage string", - "\tCount int", - "}", - "", - "func NewMessage() string { return SQLText }", - "", - "func NewCount() int { return defaultCount }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: helper(msg), Count: count}", - "}", - "", - }, "\n")) - writeFile(t, fx.wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: false}, - }, - { - name: "struct_field_addition_uses_local_fastpath", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type 
Foo struct {", - "\tMessage string", - "\tCount int", - "}", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg), Count: defaultCount}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: true}, - }, - { - name: "interface_method_addition_uses_local_fastpath", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Fooer interface {", - "\tMessage() string", - "\tCount() int", - "}", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "local_fastpath", wantSameOutput: true}, - }, - { - name: "new_source_file_uses_local_fastpath", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.extraFile, strings.Join([]string{ - "package dep", - "", - "func NewTag() string { return \"tag\" }", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "fast", wantSameOutput: true}, - }, - { - name: "invalid_body_change_falls_back_and_errors", - apply: func(t *testing.T, fx incrementalScenarioFixture) { - writeFile(t, fx.depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return missing }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - }, - want: scenarioExpectation{mode: "generate_load", wantErr: true}, - }, - } - 
- for _, scenario := range scenarios { - scenario := scenario - t.Run(scenario.name, func(t *testing.T) { - fx := newIncrementalScenarioFixture(t) - - first, errs := Generate(fx.ctx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("baseline Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected baseline Generate result: %+v", first) - } - - scenario.apply(t, fx) - - var labels []string - timedCtx := WithTiming(fx.ctx, func(label string, _ time.Duration) { - labels = append(labels, label) - }) - second, errs := Generate(timedCtx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) - - if scenario.want.wantErr { - if len(errs) == 0 { - t.Fatal("expected Generate to return errors") - } - if len(second) != 0 { - t.Fatalf("expected invalid incremental generate to stop before generation, got %+v", second) - } - } else { - if len(errs) > 0 { - t.Fatalf("incremental Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected incremental Generate result: %+v", second) - } - } - - switch scenario.want.mode { - case "preload": - if containsLabel(labels, "generate.load") { - t.Fatalf("expected preload reuse without generate.load, labels=%v", labels) - } - case "fast": - if containsLabel(labels, "generate.load") { - t.Fatalf("expected fast incremental path without generate.load, labels=%v", labels) - } - case "local_fastpath": - if containsLabel(labels, "generate.load") { - t.Fatalf("expected local fast path without generate.load, labels=%v", labels) - } - if containsLabel(labels, "load.packages.lazy.load") { - t.Fatalf("expected local fast path to skip lazy load, labels=%v", labels) - } - if !containsLabel(labels, "incremental.local_fastpath.load") { - t.Fatalf("expected local fast path load, labels=%v", labels) - } - case "generate_load": - if !containsLabel(labels, "generate.load") { - t.Fatalf("expected generate.load 
fallback, labels=%v", labels) - } - default: - t.Fatalf("unknown expected mode %q", scenario.want.mode) - } - - if scenario.want.wantErr { - return - } - - normal, errs := Generate(context.Background(), fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("normal Generate returned errors after edit: %v", errs) - } - if len(normal) != 1 || len(normal[0].Errs) > 0 { - t.Fatalf("unexpected normal Generate result after edit: %+v", normal) - } - if second[0].OutputPath != normal[0].OutputPath { - t.Fatalf("output paths differ: incremental=%q normal=%q", second[0].OutputPath, normal[0].OutputPath) - } - if string(second[0].Content) != string(normal[0].Content) { - t.Fatalf("incremental output differs from normal output after %s", scenario.name) - } - if scenario.want.wantSameOutput && string(first[0].Content) != string(second[0].Content) { - t.Fatalf("expected generated output to stay unchanged for %s", scenario.name) - } - if !scenario.want.wantSameOutput && string(first[0].Content) == string(second[0].Content) { - t.Fatalf("expected generated output to change for %s", scenario.name) - } - }) - } -} - -func TestGenerateIncrementalShapeChangeFallsBackToLazyLoad(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - 
"\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - wireFile := filepath.Join(root, "dep", "wire.go") - writeFile(t, wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - var firstLabels []string - firstCtx := WithTiming(ctx, func(label string, _ time.Duration) { - firstLabels = append(firstLabels, label) - }) - first, errs := Generate(firstCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - if !containsLabel(firstLabels, "load.packages.lazy.load") { - t.Fatalf("expected first incremental generate to perform lazy load, labels=%v", firstLabels) - } - - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string; Count int }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func NewCount() int { return 7 }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: msg, Count: count}", - "}", - "", - }, "\n")) - writeFile(t, wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - - var secondLabels []string - secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { - secondLabels = append(secondLabels, label) - }) - second, errs 
:= Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("second Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected second Generate result: %+v", second) - } - if containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected shape-changing incremental run to skip package load via local fast path, labels=%v", secondLabels) - } - if containsLabel(secondLabels, "load.packages.lazy.load") { - t.Fatalf("expected shape-changing incremental run to skip lazy load via local fast path, labels=%v", secondLabels) - } - if !containsLabel(secondLabels, "incremental.local_fastpath.load") { - t.Fatalf("expected shape-changing incremental run to use local fast path, labels=%v", secondLabels) - } - if string(first[0].Content) == string(second[0].Content) { - t.Fatal("expected shape-changing edit to regenerate different output") - } -} - -func TestGenerateIncrementalRepeatedShapeStateHitsPreloadManifest(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { 
Message string; Count int }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func NewCount() int { return 7 }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: msg, Count: count}", - "}", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - - var secondLabels []string - secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { - secondLabels = append(secondLabels, label) - }) - second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("second Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected second Generate result: %+v", second) - } - if containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected repeated shape state to hit preload manifest before package load, labels=%v", secondLabels) - } - if containsLabel(secondLabels, "load.packages.lazy.load") { - t.Fatalf("expected repeated shape state to skip lazy load, labels=%v", secondLabels) - } - if string(first[0].Content) != string(second[0].Content) { - t.Fatal("expected repeated shape state to reuse identical generated output") - } -} - -func TestGenerateIncrementalShapeChangeThenRepeatHitsPreloadManifest(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return 
cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "extra", "extra.go"), strings.Join([]string{ - "package extra", - "", - "type Marker struct{}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - wireFile := filepath.Join(root, "dep", "wire.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - writeFile(t, wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string; Count int }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func NewCount() int { return 7 }", - "", - "func New(msg string, count int) *Foo {", - 
"\treturn &Foo{Message: msg, Count: count}", - "}", - "", - }, "\n")) - writeFile(t, wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - - var secondLabels []string - secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { - secondLabels = append(secondLabels, label) - }) - second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("second Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected second Generate result: %+v", second) - } - if containsLabel(secondLabels, "generate.load") { - t.Fatalf("expected shape-changing Generate to skip package load via local fast path, labels=%v", secondLabels) - } - if containsLabel(secondLabels, "load.packages.lazy.load") { - t.Fatalf("expected shape-changing Generate to skip lazy load via local fast path, labels=%v", secondLabels) - } - if !containsLabel(secondLabels, "incremental.local_fastpath.load") { - t.Fatalf("expected shape-changing Generate to use local fast path, labels=%v", secondLabels) - } - - var thirdLabels []string - thirdCtx := WithTiming(ctx, func(label string, _ time.Duration) { - thirdLabels = append(thirdLabels, label) - }) - third, errs := Generate(thirdCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("third Generate returned errors: %v", errs) - } - if len(third) != 1 || len(third[0].Errs) > 0 { - t.Fatalf("unexpected third Generate result: %+v", third) - } - if containsLabel(thirdLabels, "generate.load") { - t.Fatalf("expected repeated shape-changing state to hit preload manifest before package load, labels=%v", thirdLabels) - } - if containsLabel(thirdLabels, "load.packages.lazy.load") { - t.Fatalf("expected repeated shape-changing state to skip lazy load, labels=%v", thirdLabels) - } - if string(second[0].Content) 
!= string(third[0].Content) { - t.Fatal("expected repeated shape-changing state to reuse identical generated output") - } -} - -func TestGenerateIncrementalShapeChangeMatchesNormalGenerate(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeIncrementalBenchmarkModule(t, repoRoot, root) - - env := append(os.Environ(), "GOWORK=off") - incrementalCtx := WithIncremental(context.Background(), true) - - if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - t.Fatalf("baseline incremental Generate returned errors: %v", errs) - } - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string; Count int }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func NewCount() int { return 7 }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: msg, Count: count}", - "}", - "", - }, "\n")) - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - - var incrementalLabels []string - incrementalTimedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { - incrementalLabels = append(incrementalLabels, label) - }) - incrementalGens, errs := Generate(incrementalTimedCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("incremental shape-change Generate returned errors: %v", errs) - } - if len(incrementalGens) != 1 || len(incrementalGens[0].Errs) > 0 { - t.Fatalf("unexpected incremental Generate results: %+v", incrementalGens) - } - if !containsLabel(incrementalLabels, "incremental.local_fastpath.load") { - 
t.Fatalf("expected incremental shape-change Generate to use local fast path, labels=%v", incrementalLabels) - } - - normalGens, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("normal Generate returned errors: %v", errs) - } - if len(normalGens) != 1 || len(normalGens[0].Errs) > 0 { - t.Fatalf("unexpected normal Generate results: %+v", normalGens) - } - if incrementalGens[0].OutputPath != normalGens[0].OutputPath { - t.Fatalf("output paths differ: incremental=%q normal=%q", incrementalGens[0].OutputPath, normalGens[0].OutputPath) - } - if string(incrementalGens[0].Content) != string(normalGens[0].Content) { - t.Fatal("shape-changing incremental output differs from normal Generate output") - } -} - -func TestGenerateIncrementalColdBootstrapStillSeedsFastPath(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - writeLargeBenchmarkModule(t, repoRoot, root, 24) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - if _, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - t.Fatalf("cold bootstrap Generate returned errors: %v", errs) - } - - mutateLargeBenchmarkModule(t, root, 12) - - var labels []string - timedCtx := WithTiming(ctx, func(label string, _ time.Duration) { - labels = append(labels, label) - }) - gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("shape-change Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - t.Fatalf("unexpected Generate results: %+v", gens) - } - if !containsLabel(labels, "incremental.local_fastpath.load") { - t.Fatalf("expected cold bootstrap to seed fast path, labels=%v", labels) - } -} 
- -func TestLoadLocalPackagesForFastPathImportsUnchangedLocalDependencyFromLocalExport(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - writeDepRouterModule(t, root, repoRoot) - - env := append(os.Environ(), "GOWORK=off") - incrementalCtx := WithIncremental(context.Background(), true) - - if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - t.Fatalf("baseline incremental Generate returned errors: %v", errs) - } - - depPkgPath := "example.com/app/dep" - depExportPath := mustLocalExportPath(t, root, env, depPkgPath) - if _, err := os.Stat(depExportPath); err != nil { - t.Fatalf("expected local export artifact at %s: %v", depExportPath, err) - } - - mutateRouterModule(t, root) - - preloadState, ok := prepareIncrementalPreloadState(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) - if !ok || preloadState == nil || preloadState.manifest == nil { - t.Fatal("expected preload state after baseline incremental generate") - } - loaded, err := loadLocalPackagesForFastPath(context.Background(), root, "", "example.com/app/app", []string{"example.com/app/router"}, preloadState.currentLocal, preloadState.manifest.ExternalPkgs) - if err != nil { - t.Fatalf("loadLocalPackagesForFastPath returned error: %v", err) - } - if _, ok := loaded.loader.localExports[depPkgPath]; !ok { - t.Fatalf("expected %s to be a local export candidate", depPkgPath) - } - if _, ok := loaded.loader.sourcePkgs[depPkgPath]; ok { - t.Fatalf("did not expect %s to be source-loaded", depPkgPath) - } - typesPkg, err := loaded.loader.importPackage(depPkgPath) - if err != nil { - t.Fatalf("importPackage(%s) returned error: %v", depPkgPath, err) - } - if typesPkg == nil || !typesPkg.Complete() { - t.Fatalf("expected complete imported package for 
%s, got %#v", depPkgPath, typesPkg) - } - if loaded.loader.pkgs[depPkgPath] != nil { - t.Fatalf("expected %s to avoid source loading when local export artifact is present", depPkgPath) - } -} - -func TestGenerateIncrementalMissingLocalExportFallsBackSafely(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - writeDepRouterModule(t, root, repoRoot) - - env := append(os.Environ(), "GOWORK=off") - incrementalCtx := WithIncremental(context.Background(), true) - - if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - t.Fatalf("baseline incremental Generate returned errors: %v", errs) - } - - depExportPath := mustLocalExportPath(t, root, env, "example.com/app/dep") - if err := os.Remove(depExportPath); err != nil { - t.Fatalf("Remove(%s) failed: %v", depExportPath, err) - } - - mutateRouterModule(t, root) - - var labels []string - timedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { - labels = append(labels, label) - }) - gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - t.Fatalf("unexpected Generate results: %+v", gens) - } - if !containsLabel(labels, "incremental.local_fastpath.load") { - t.Fatalf("expected missing local export to stay on local fast path, labels=%v", labels) - } - refreshedExportPath := mustLocalExportPath(t, root, env, "example.com/app/dep") - if _, err := os.Stat(refreshedExportPath); err != nil { - t.Fatalf("expected local export artifact to be refreshed at %s: %v", refreshedExportPath, err) - } -} - -func TestGenerateIncrementalCorruptedLocalExportFallsBackSafely(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - 
t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - writeDepRouterModule(t, root, repoRoot) - - env := append(os.Environ(), "GOWORK=off") - incrementalCtx := WithIncremental(context.Background(), true) - - if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); len(errs) > 0 { - t.Fatalf("baseline incremental Generate returned errors: %v", errs) - } - - depExportPath := mustLocalExportPath(t, root, env, "example.com/app/dep") - if err := os.WriteFile(depExportPath, []byte("not-a-valid-export"), 0644); err != nil { - t.Fatalf("WriteFile(%s) failed: %v", depExportPath, err) - } - - mutateRouterModule(t, root) - - var labels []string - timedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { - labels = append(labels, label) - }) - gens, errs := Generate(timedCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("Generate returned errors: %v", errs) - } - if len(gens) != 1 || len(gens[0].Errs) > 0 { - t.Fatalf("unexpected Generate results: %+v", gens) - } - if !containsLabel(labels, "incremental.local_fastpath.load") { - t.Fatalf("expected corrupted local export to stay on local fast path, labels=%v", labels) - } - refreshedExportPath := mustLocalExportPath(t, root, env, "example.com/app/dep") - data, err := os.ReadFile(refreshedExportPath) - if err != nil { - t.Fatalf("ReadFile(%s) failed: %v", refreshedExportPath, err) - } - if string(data) == "not-a-valid-export" { - t.Fatalf("expected corrupted local export artifact to be refreshed at %s", refreshedExportPath) - } -} - -func TestGenerateIncrementalShapeChangeWithUnchangedDependentPackageMatchesNormalGenerate(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - 
repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"example.com/app/router\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *router.Routes {", - "\twire.Build(dep.Set, router.Set)", - "\treturn nil", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Controller struct { Message string }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func NewController(msg string) *Controller {", - "\treturn &Controller{Message: msg}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(NewMessage, NewController)", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "router", "router.go"), strings.Join([]string{ - "package router", - "", - "import \"example.com/app/dep\"", - "", - "type Routes struct { Controller *dep.Controller }", - "", - "func ProvideRoutes(controller *dep.Controller) *Routes {", - "\treturn &Routes{Controller: controller}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "router", "wire.go"), strings.Join([]string{ - "package router", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(ProvideRoutes)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - incrementalCtx := WithIncremental(context.Background(), true) - - if _, errs := Generate(incrementalCtx, root, env, []string{"./app"}, &GenerateOptions{}); 
len(errs) > 0 { - t.Fatalf("baseline incremental Generate returned errors: %v", errs) - } - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Controller struct { Message string; Count int }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func NewCount() int { return 7 }", - "", - "func NewController(msg string, count int) *Controller {", - "\treturn &Controller{Message: msg, Count: count}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(NewMessage, NewCount, NewController)", - "", - }, "\n")) - - var incrementalLabels []string - incrementalTimedCtx := WithTiming(incrementalCtx, func(label string, _ time.Duration) { - incrementalLabels = append(incrementalLabels, label) - }) - incrementalGens, errs := Generate(incrementalTimedCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("incremental Generate returned errors: %v", errs) - } - if len(incrementalGens) != 1 || len(incrementalGens[0].Errs) > 0 { - t.Fatalf("unexpected incremental Generate results: %+v", incrementalGens) - } - if !containsLabel(incrementalLabels, "incremental.local_fastpath.load") { - t.Fatalf("expected incremental Generate to use local fast path, labels=%v", incrementalLabels) - } - - normalGens, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("normal Generate returned errors: %v", errs) - } - if len(normalGens) != 1 || len(normalGens[0].Errs) > 0 { - t.Fatalf("unexpected normal Generate results: %+v", normalGens) - } - if string(incrementalGens[0].Content) != string(normalGens[0].Content) { - t.Fatal("incremental output differs from normal Generate output when unchanged package depends on changed package") - } -} - -func 
TestGenerateIncrementalInvalidShapeChangeDoesNotReuseManifest(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - wireFile := filepath.Join(root, "dep", "wire.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - writeFile(t, wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "import \"example.com/app/extra\"", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() 
string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - var secondLabels []string - secondCtx := WithTiming(ctx, func(label string, _ time.Duration) { - secondLabels = append(secondLabels, label) - }) - second, errs := Generate(secondCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(second) != 0 { - t.Fatalf("expected invalid incremental generate to stop before generation, got %+v", second) - } - if len(errs) == 0 { - t.Fatal("expected invalid incremental generate to return errors") - } - if got := errs[0].Error(); !strings.Contains(got, "type-check failed for example.com/app/app") { - t.Fatalf("expected fast-path type-check error, got %q", got) - } - if _, ok := readIncrementalManifest(incrementalManifestSelectorKey(root, env, []string{"./app"}, &GenerateOptions{})); ok { - t.Fatal("expected invalid incremental generate to invalidate selector manifest") - } -} - -func TestGenerateIncrementalRecoversAfterInvalidShapeChange(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - wireFile := filepath.Join(root, "dep", "wire.go") - writeFile(t, depFile, 
strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - writeFile(t, wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "import \"example.com/app/extra\"", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n")) - - second, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(second) != 0 { - t.Fatalf("expected invalid incremental generate to stop before generation, got %+v", second) - } - if len(errs) == 0 { - t.Fatal("expected invalid incremental generate to return errors") - } - clearIncrementalSessions() - - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string; Count int }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func NewCount() int { return 7 }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: msg, Count: count}", - "}", - "", - }, "\n")) - writeFile(t, wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n")) - - var thirdLabels []string - thirdCtx 
:= WithTiming(ctx, func(label string, _ time.Duration) { - thirdLabels = append(thirdLabels, label) - }) - third, errs := Generate(thirdCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("recovery incremental Generate returned errors: %v", errs) - } - if len(third) != 1 || len(third[0].Errs) > 0 { - t.Fatalf("unexpected recovery incremental Generate result: %+v", third) - } - - normal, errs := Generate(context.Background(), root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("normal Generate returned errors: %v", errs) - } - if len(normal) != 1 || len(normal[0].Errs) > 0 { - t.Fatalf("unexpected normal Generate result: %+v", normal) - } - if string(third[0].Content) != string(normal[0].Content) { - t.Fatal("incremental output differs from normal Generate output after recovering from invalid shape change") - } - if !containsLabel(thirdLabels, "incremental.local_fastpath.load") && !containsLabel(thirdLabels, "generate.load") { - t.Fatalf("expected recovery run to rebuild through local fast path or normal load, labels=%v", thirdLabels) - } -} - -func TestGenerateIncrementalToggleBackToKnownShapeHitsArchivedPreloadManifest(t *testing.T) { - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - 
"\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - wireFile := filepath.Join(root, "dep", "wire.go") - - oldDep := strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: msg}", - "}", - "", - }, "\n") - newDep := strings.Join([]string{ - "package dep", - "", - "type Foo struct { Message string; Count int }", - "", - "func NewMessage() string { return \"a\" }", - "", - "func NewCount() int { return 7 }", - "", - "func New(msg string, count int) *Foo {", - "\treturn &Foo{Message: msg, Count: count}", - "}", - "", - }, "\n") - oldWire := strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n") - newWire := strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, NewCount, New)", - "", - }, "\n") - - writeFile(t, depFile, oldDep) - writeFile(t, wireFile, oldWire) - - env := append(os.Environ(), "GOWORK=off") - ctx := WithIncremental(context.Background(), true) - - first, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("first Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected first Generate result: %+v", first) - } - - writeFile(t, depFile, newDep) - writeFile(t, wireFile, newWire) - second, errs := Generate(ctx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("second Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected second Generate result: %+v", second) - } - - writeFile(t, depFile, oldDep) - writeFile(t, wireFile, oldWire) - - var thirdLabels []string - thirdCtx := WithTiming(ctx, func(label 
string, _ time.Duration) { - thirdLabels = append(thirdLabels, label) - }) - third, errs := Generate(thirdCtx, root, env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("third Generate returned errors: %v", errs) - } - if len(third) != 1 || len(third[0].Errs) > 0 { - t.Fatalf("unexpected third Generate result: %+v", third) - } - if containsLabel(thirdLabels, "generate.load") { - t.Fatalf("expected toggled-back shape state to hit archived preload manifest before package load, labels=%v", thirdLabels) - } - if containsLabel(thirdLabels, "load.packages.lazy.load") { - t.Fatalf("expected toggled-back shape state to skip lazy load, labels=%v", thirdLabels) - } - if string(first[0].Content) != string(third[0].Content) { - t.Fatal("expected toggled-back shape state to reuse archived generated output") - } -} - -func TestGenerateIncrementalPreloadHitRefreshesMissingContentHashes(t *testing.T) { - fx := newIncrementalScenarioFixture(t) - - first, errs := Generate(fx.ctx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("baseline Generate returned errors: %v", errs) - } - if len(first) != 1 || len(first[0].Errs) > 0 { - t.Fatalf("unexpected baseline Generate result: %+v", first) - } - - selectorKey := incrementalManifestSelectorKey(fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) - manifest, ok := readIncrementalManifest(selectorKey) - if !ok { - t.Fatal("expected incremental manifest after baseline generate") - } - if len(manifest.LocalPackages) == 0 { - t.Fatal("expected local packages in incremental manifest") - } - - stale := *manifest - stale.LocalPackages = append([]packageFingerprint(nil), manifest.LocalPackages...) 
- for i := range stale.LocalPackages { - stale.LocalPackages[i].ContentHash = "" - stale.LocalPackages[i].Dirs = nil - } - writeIncrementalManifestFile(selectorKey, &stale) - writeIncrementalManifestFile(incrementalManifestStateKey(selectorKey, stale.LocalPackages), &stale) - - second, errs := Generate(fx.ctx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) - if len(errs) > 0 { - t.Fatalf("refresh Generate returned errors: %v", errs) - } - if len(second) != 1 || len(second[0].Errs) > 0 { - t.Fatalf("unexpected refresh Generate result: %+v", second) - } - - preloadState, ok := prepareIncrementalPreloadState(fx.ctx, fx.root, fx.env, []string{"./app"}, &GenerateOptions{}) - if !ok { - t.Fatal("expected preload state after manifest refresh") - } - if !preloadState.valid { - t.Fatalf("expected refreshed preload state to be valid, reason=%s", preloadState.reason) - } - if len(preloadState.touched) != 0 { - t.Fatalf("expected refreshed preload state to have no touched packages, got %v", preloadState.touched) - } -} - -func containsLabel(labels []string, want string) bool { - for _, label := range labels { - if label == want { - return true - } - } - return false -} - -type incrementalScenarioFixture struct { - root string - env []string - ctx context.Context - depFile string - wireFile string - extraFile string -} - -func newIncrementalScenarioFixture(t *testing.T) incrementalScenarioFixture { - t.Helper() - - lockCacheHooks(t) - state := saveCacheHooks() - t.Cleanup(func() { restoreCacheHooks(state) }) - - cacheRoot := t.TempDir() - osTempDir = func() string { return cacheRoot } - - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build 
wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.NewSet)", - "\treturn nil", - "}", - "", - }, "\n")) - - depFile := filepath.Join(root, "dep", "dep.go") - writeFile(t, depFile, strings.Join([]string{ - "package dep", - "", - "const SQLText = \"green\"", - "", - "var defaultCount = 1", - "", - "type Foo struct { Message string }", - "", - "func NewMessage() string { return SQLText }", - "", - "func helper(msg string) string { return msg }", - "", - "func New(msg string) *Foo {", - "\treturn &Foo{Message: helper(msg)}", - "}", - "", - }, "\n")) - - wireFile := filepath.Join(root, "dep", "wire.go") - writeFile(t, wireFile, strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var NewSet = wire.NewSet(NewMessage, New)", - "", - }, "\n")) - - return incrementalScenarioFixture{ - root: root, - env: append(os.Environ(), "GOWORK=off"), - ctx: WithIncremental(context.Background(), true), - depFile: depFile, - wireFile: wireFile, - extraFile: filepath.Join(root, "dep", "extra.go"), - } -} - -func mustRepoRoot(t *testing.T) string { - t.Helper() - wd, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd failed: %v", err) - } - repoRoot := filepath.Clean(filepath.Join(wd, "..", "..")) - if _, err := os.Stat(filepath.Join(repoRoot, "go.mod")); err != nil { - t.Fatalf("repo root not found at %s: %v", repoRoot, err) - } - return repoRoot -} - -func writeDepRouterModule(t *testing.T, root string, repoRoot string) { - t.Helper() - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build 
wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"example.com/app/router\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *router.Routes {", - "\twire.Build(dep.Set, router.Set)", - "\treturn nil", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Controller struct { Message string }", - "", - "func NewMessage() string { return \"ok\" }", - "", - "func NewController(msg string) *Controller {", - "\treturn &Controller{Message: msg}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "wire.go"), strings.Join([]string{ - "package dep", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(NewMessage, NewController)", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "router", "router.go"), strings.Join([]string{ - "package router", - "", - "import \"example.com/app/dep\"", - "", - "type Routes struct { Controller *dep.Controller }", - "", - "func ProvideRoutes(controller *dep.Controller) *Routes {", - "\treturn &Routes{Controller: controller}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "router", "wire.go"), strings.Join([]string{ - "package router", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(ProvideRoutes)", - "", - }, "\n")) -} - -func mutateRouterModule(t *testing.T, root string) { - t.Helper() - writeFile(t, filepath.Join(root, "router", "router.go"), strings.Join([]string{ - "package router", - "", - "import \"example.com/app/dep\"", - "", - "type Routes struct {", - "\tController *dep.Controller", - "\tVersion int", - "}", - "", - "func NewVersion() int {", - "\treturn 2", - "}", - "", - "func ProvideRoutes(controller *dep.Controller, version int) *Routes {", - "\treturn &Routes{Controller: controller, Version: version}", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "router", "wire.go"), 
strings.Join([]string{ - "package router", - "", - "import \"github.com/goforj/wire\"", - "", - "var Set = wire.NewSet(NewVersion, ProvideRoutes)", - "", - }, "\n")) -} - -func mustLocalExportPath(t *testing.T, root string, env []string, pkgPath string) string { - t.Helper() - pkgs, loader, errs := load(context.Background(), root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("load returned errors: %v", errs) - } - if loader == nil { - t.Fatal("load returned nil loader") - } - if _, errs := loader.load("example.com/app/app"); len(errs) > 0 { - t.Fatalf("lazy load returned errors: %v", errs) - } - snapshot := buildIncrementalManifestSnapshotFromPackages(root, "", incrementalManifestPackages(pkgs, loader)) - if snapshot == nil || snapshot.fingerprints[pkgPath] == nil { - t.Fatalf("missing fingerprint for %s", pkgPath) - } - path := localExportPathForFingerprint(root, "", snapshot.fingerprints[pkgPath]) - if path == "" { - t.Fatalf("missing local export path for %s", pkgPath) - } - return path -} - -func writeFile(t *testing.T, path string, content string) { - t.Helper() - if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { - t.Fatalf("MkdirAll failed: %v", err) - } - if err := os.WriteFile(path, []byte(content), 0644); err != nil { - t.Fatalf("WriteFile failed: %v", err) - } -} diff --git a/internal/wire/loader_timing_bridge.go b/internal/wire/loader_timing_bridge.go new file mode 100644 index 0000000..2de0245 --- /dev/null +++ b/internal/wire/loader_timing_bridge.go @@ -0,0 +1,17 @@ +package wire + +import ( + "context" + "time" + + "github.com/goforj/wire/internal/loader" +) + +func withLoaderTiming(ctx context.Context) context.Context { + if t := timing(ctx); t != nil { + return loader.WithTiming(ctx, func(label string, d time.Duration) { + t(label, d) + }) + } + return ctx +} diff --git a/internal/wire/cache_hooks.go b/internal/wire/loader_validation.go similarity index 50% rename from internal/wire/cache_hooks.go rename to 
internal/wire/loader_validation.go index 9d4be6d..6868b7b 100644 --- a/internal/wire/cache_hooks.go +++ b/internal/wire/loader_validation.go @@ -15,27 +15,19 @@ package wire import ( - "encoding/json" - "os" -) + "context" -var ( - osCreateTemp = os.CreateTemp - osMkdirAll = os.MkdirAll - osReadFile = os.ReadFile - osRemove = os.Remove - osRemoveAll = os.RemoveAll - osRename = os.Rename - osStat = os.Stat - osTempDir = os.TempDir + "github.com/goforj/wire/internal/loader" +) - jsonMarshal = json.Marshal - jsonUnmarshal = json.Unmarshal +func loaderValidationMode(ctx context.Context, wd string, env []string) bool { + return effectiveLoaderMode(ctx, wd, env) != loader.ModeFallback +} - cacheKeyForPackageFunc = cacheKeyForPackage - detectOutputDirFunc = detectOutputDir - buildCacheFilesFunc = buildCacheFiles - buildCacheFilesFromMetaFunc = buildCacheFilesFromMeta - rootPackageFilesFunc = rootPackageFiles - hashFilesFunc = hashFiles -) +func effectiveLoaderMode(ctx context.Context, wd string, env []string) loader.Mode { + mode := loader.ModeFromEnv(env) + if mode != loader.ModeAuto { + return mode + } + return loader.ModeAuto +} diff --git a/internal/wire/local_export.go b/internal/wire/local_export.go deleted file mode 100644 index f83ed7b..0000000 --- a/internal/wire/local_export.go +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package wire - -import ( - "crypto/sha256" - "fmt" - "go/token" - "go/types" - "path/filepath" - - "golang.org/x/tools/go/gcexportdata" - "golang.org/x/tools/go/packages" -) - -const localExportVersion = "wire-local-export-v1" - -func localExportKey(wd string, tags string, pkgPath string, shapeHash string) string { - sum := sha256.Sum256([]byte(localExportVersion + "\x00" + packageCacheScope(wd) + "\x00" + tags + "\x00" + pkgPath + "\x00" + shapeHash)) - return fmt.Sprintf("%x", sum[:]) -} - -func localExportPath(key string) string { - return filepath.Join(cacheDir(), key+".iexp") -} - -func localExportPathForFingerprint(wd string, tags string, fp *packageFingerprint) string { - if fp == nil || fp.PkgPath == "" || fp.ShapeHash == "" { - return "" - } - return localExportPath(localExportKey(wd, tags, fp.PkgPath, fp.ShapeHash)) -} - -func localExportExists(wd string, tags string, fp *packageFingerprint) bool { - path := localExportPathForFingerprint(wd, tags, fp) - if path == "" { - return false - } - _, err := osStat(path) - return err == nil -} - -func writeLocalPackageExports(wd string, tags string, pkgs []*packages.Package, fps map[string]*packageFingerprint) { - if len(pkgs) == 0 || len(fps) == 0 { - return - } - moduleRoot := findModuleRoot(wd) - for _, pkg := range pkgs { - if pkg == nil || pkg.Types == nil || pkg.PkgPath == "" { - continue - } - if classifyPackageLocation(moduleRoot, pkg) != "local" { - continue - } - fp := fps[pkg.PkgPath] - path := localExportPathForFingerprint(wd, tags, fp) - if path == "" { - continue - } - writeLocalPackageExportFile(path, pkg.Fset, pkg.Types) - } -} - -func writeLocalPackageExportFile(path string, fset *token.FileSet, pkg *types.Package) { - if path == "" || fset == nil || pkg == nil { - return - } - dir := cacheDir() - if err := osMkdirAll(dir, 0755); err != nil { - return - } - tmp, err := osCreateTemp(dir, filepath.Base(path)+".tmp-") - if err != nil { - return - } - writeErr := gcexportdata.Write(tmp, fset, pkg) 
- closeErr := tmp.Close() - if writeErr != nil || closeErr != nil { - osRemove(tmp.Name()) - return - } - if err := osRename(tmp.Name(), path); err != nil { - osRemove(tmp.Name()) - } -} diff --git a/internal/wire/local_fastpath.go b/internal/wire/local_fastpath.go deleted file mode 100644 index 89ea402..0000000 --- a/internal/wire/local_fastpath.go +++ /dev/null @@ -1,1031 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "fmt" - "go/ast" - "go/format" - importerpkg "go/importer" - "go/parser" - "go/token" - "go/types" - "io" - "os" - "path/filepath" - "runtime" - "sort" - "strings" - "time" - - "golang.org/x/tools/go/gcexportdata" - "golang.org/x/tools/go/packages" -) - -func tryIncrementalLocalFastPath(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions, state *incrementalPreloadState) ([]GenerateResult, bool, bool, []error) { - if state == nil || state.manifest == nil { - return nil, false, false, nil - } - if !strings.HasSuffix(state.reason, ".shape_mismatch") { - return nil, false, false, nil - } - roots := manifestOutputPkgPaths(state.manifest) - if len(roots) != 1 { - return nil, false, false, nil - } - changed := changedPackagePaths(state.manifest.LocalPackages, state.currentLocal) - if len(changed) != 1 { - return nil, false, false, nil - } - graph, ok := readIncrementalGraph(incrementalGraphKey(wd, opts.Tags, roots)) - if !ok { 
- return nil, false, false, nil - } - affected := affectedRoots(graph, changed) - if len(affected) != 1 || affected[0] != roots[0] { - return nil, false, false, nil - } - - fastPathStart := time.Now() - loaded, err := loadLocalPackagesForFastPath(ctx, wd, opts.Tags, roots[0], changed, state.currentLocal, state.manifest.ExternalPkgs) - if err != nil { - debugf(ctx, "incremental.local_fastpath miss reason=%v", err) - if shouldBypassIncrementalManifestAfterFastPathError(err) { - invalidateIncrementalPreloadState(state) - return nil, true, true, []error{err} - } - return nil, false, false, nil - } - logTiming(ctx, "incremental.local_fastpath.load", fastPathStart) - - generated, errs := generateFromTypedPackages(ctx, loaded, opts) - logTiming(ctx, "incremental.local_fastpath.generate", fastPathStart) - if len(errs) > 0 { - return nil, true, true, errs - } - - snapshot := &incrementalFingerprintSnapshot{ - fingerprints: loaded.fingerprints, - changed: append([]string(nil), changed...), - } - loader := &lazyLoader{ - ctx: ctx, - wd: wd, - env: env, - tags: opts.Tags, - fset: loaded.fset, - fingerprints: snapshot, - loaded: make(map[string]*packages.Package, len(loaded.byPath)), - } - for path, pkg := range loaded.byPath { - loader.loaded[path] = pkg - } - changedSet := make(map[string]struct{}, len(snapshot.changed)) - for _, path := range snapshot.changed { - changedSet[path] = struct{}{} - } - currentPackages := loaded.currentPackages() - writeIncrementalFingerprints(snapshot, wd, opts.Tags) - writeLocalPackageExports(wd, opts.Tags, currentPackages, loaded.fingerprints) - writeIncrementalPackageSummariesWithSummary(loader, currentPackages, newSummaryProviderResolver(ctx, loaded.loader.summaries, loaded.loader.importExportPackage), changedSet) - writeIncrementalManifestFromState(wd, env, patterns, opts, state, snapshot, generated) - writeIncrementalGraphFromSnapshot(wd, opts.Tags, roots, loaded.fingerprints) - - debugf(ctx, "incremental.local_fastpath hit root=%s 
changed=%s", roots[0], strings.Join(changed, ",")) - return generated, true, false, nil -} - -func validateIncrementalTouchedPackages(ctx context.Context, wd string, opts *GenerateOptions, state *incrementalPreloadState, snapshot *incrementalFingerprintSnapshot) error { - if state == nil || state.manifest == nil || snapshot == nil || len(snapshot.touched) == 0 { - return nil - } - roots := manifestOutputPkgPaths(state.manifest) - if len(roots) != 1 { - return nil - } - _, err := loadLocalPackagesForFastPath(ctx, wd, opts.Tags, roots[0], snapshot.touched, snapshotPackageFingerprints(snapshot), state.manifest.ExternalPkgs) - return err -} - -func shouldBypassIncrementalManifestAfterFastPathError(err error) bool { - if err == nil { - return false - } - msg := err.Error() - if strings.Contains(msg, "missing external export data for ") { - return false - } - return strings.Contains(msg, "type-check failed for ") -} - -func invalidateIncrementalPreloadState(state *incrementalPreloadState) { - if state == nil { - return - } - removeIncrementalManifestFile(state.selectorKey) -} - -func formatLocalTypeCheckError(wd string, pkgPath string, errs []packages.Error) error { - if len(errs) == 0 { - return fmt.Errorf("type-check failed for %s", pkgPath) - } - root := findModuleRoot(wd) - lines := []string{} - for _, pkgErr := range errs { - details := normalizeErrorLines(pkgErr.Msg, root) - if len(details) == 0 { - continue - } - lines = append(lines, fmt.Sprintf("type-check failed for %s: %s", pkgPath, details[0])) - for _, line := range details[1:] { - lines = append(lines, line) - } - } - if len(lines) == 0 { - lines = append(lines, fmt.Sprintf("type-check failed for %s", pkgPath)) - } - return fmt.Errorf("%s", strings.Join(lines, "\n")) -} - -func normalizeErrorLines(msg string, root string) []string { - msg = strings.TrimSpace(msg) - if msg == "" { - return []string{"unknown error"} - } - lines := unfoldTypeCheckChain(msg) - for i := range lines { - lines[i] = 
relativizeErrorLine(lines[i], root) - } - if len(lines) == 0 { - return []string{"unknown error"} - } - return lines -} - -func relativizeErrorLine(line string, root string) string { - if root == "" { - return line - } - cleanRoot := filepath.Clean(root) - prefix := cleanRoot + string(os.PathSeparator) - return strings.ReplaceAll(line, prefix, "") -} - -func unfoldTypeCheckChain(msg string) []string { - msg = strings.TrimSpace(msg) - if msg == "" { - return nil - } - if inner, outer, ok := splitNestedTypeCheck(msg); ok { - lines := []string{strings.TrimSpace(outer)} - return append(lines, unfoldTypeCheckChain(inner)...) - } - parts := strings.Split(msg, "\n") - lines := make([]string, 0, len(parts)) - for _, part := range parts { - part = strings.TrimSpace(part) - if part == "" { - continue - } - lines = append(lines, part) - } - return lines -} - -func splitNestedTypeCheck(msg string) (inner string, outer string, ok bool) { - msg = strings.TrimSpace(msg) - if len(msg) < 2 || msg[len(msg)-1] != ')' { - return "", "", false - } - depth := 0 - for i := len(msg) - 1; i >= 0; i-- { - switch msg[i] { - case ')': - depth++ - case '(': - depth-- - if depth == 0 { - inner = strings.TrimSpace(msg[i+1 : len(msg)-1]) - if strings.HasPrefix(inner, "type-check failed for ") { - return inner, strings.TrimSpace(msg[:i]), true - } - return "", "", false - } - } - } - return "", "", false -} - -type localFastPathLoaded struct { - fset *token.FileSet - root *packages.Package - allPackages []*packages.Package - byPath map[string]*packages.Package - fingerprints map[string]*packageFingerprint - loader *localFastPathLoader -} - -func (l *localFastPathLoaded) currentPackages() []*packages.Package { - if l == nil { - return nil - } - if l.loader == nil || len(l.loader.pkgs) == 0 { - return l.allPackages - } - all := make([]*packages.Package, 0, len(l.loader.pkgs)) - for _, pkg := range l.loader.pkgs { - all = append(all, pkg) - } - sort.Slice(all, func(i, j int) bool { return 
all[i].PkgPath < all[j].PkgPath }) - return all -} - -type localFastPathLoader struct { - ctx context.Context - wd string - tags string - fset *token.FileSet - modulePrefix string - rootPkgPath string - changedPkgs map[string]struct{} - sourcePkgs map[string]struct{} - summaries map[string]*packageSummary - meta map[string]*packageFingerprint - pkgs map[string]*packages.Package - imported map[string]*types.Package - externalMeta map[string]externalPackageExport - localExports map[string]string - externalImp types.Importer - externalFallback types.Importer -} - -func loadLocalPackagesForFastPath(ctx context.Context, wd string, tags string, rootPkgPath string, changed []string, current []packageFingerprint, external []externalPackageExport) (*localFastPathLoaded, error) { - return loadLocalPackagesForFastPathMode(ctx, wd, tags, rootPkgPath, changed, current, external, false) -} - -func validateTouchedPackagesFastPath(ctx context.Context, wd string, tags string, touched []string, current []packageFingerprint, external []externalPackageExport) error { - if len(touched) == 0 { - return nil - } - rootPkgPath := touched[0] - _, err := loadLocalPackagesForFastPathMode(ctx, wd, tags, rootPkgPath, touched, current, external, true) - return err -} - -func loadLocalPackagesForFastPathMode(ctx context.Context, wd string, tags string, rootPkgPath string, changed []string, current []packageFingerprint, external []externalPackageExport, validationOnly bool) (*localFastPathLoaded, error) { - meta := fingerprintsFromSlice(current) - if len(meta) == 0 { - return nil, fmt.Errorf("no local fingerprints") - } - if meta[rootPkgPath] == nil { - return nil, fmt.Errorf("missing root package fingerprint") - } - externalMeta := make(map[string]externalPackageExport, len(external)) - for _, item := range external { - if item.PkgPath == "" || item.ExportFile == "" { - continue - } - if meta[item.PkgPath] != nil { - continue - } - externalMeta[item.PkgPath] = item - } - loader := 
&localFastPathLoader{ - ctx: ctx, - wd: wd, - tags: tags, - fset: token.NewFileSet(), - modulePrefix: moduleImportPrefix(meta), - rootPkgPath: rootPkgPath, - changedPkgs: make(map[string]struct{}, len(changed)), - sourcePkgs: make(map[string]struct{}), - summaries: make(map[string]*packageSummary), - meta: meta, - pkgs: make(map[string]*packages.Package, len(meta)), - imported: make(map[string]*types.Package, len(meta)+len(externalMeta)), - externalMeta: externalMeta, - localExports: make(map[string]string), - } - for _, path := range changed { - loader.changedPkgs[path] = struct{}{} - } - if validationOnly { - for path := range loader.changedPkgs { - loader.sourcePkgs[path] = struct{}{} - } - } else { - loader.markSourceClosure() - } - for path, fp := range meta { - if path == rootPkgPath { - continue - } - if _, changed := loader.changedPkgs[path]; changed { - continue - } - if _, ok := loader.sourcePkgs[path]; ok { - continue - } - if exportPath := localExportPathForFingerprint(wd, tags, fp); exportPath != "" && localExportExists(wd, tags, fp) { - loader.localExports[path] = exportPath - } - } - candidates := make(map[string]*packageSummary) - for path, fp := range meta { - if path == rootPkgPath { - continue - } - if _, changed := loader.changedPkgs[path]; changed { - continue - } - summary, ok := readIncrementalPackageSummary(incrementalSummaryKey(wd, tags, path)) - if !ok || summary == nil || summary.ShapeHash != fp.ShapeHash { - continue - } - candidates[path] = summary - } - loader.summaries = filterSupportedPackageSummaries(candidates) - loader.externalImp = importerpkg.ForCompiler(loader.fset, "gc", loader.openExternalExport) - loader.externalFallback = importerpkg.ForCompiler(loader.fset, "gc", nil) - var root *packages.Package - if validationOnly { - for _, path := range changed { - pkg, err := loader.load(path) - if err != nil { - return nil, err - } - if root == nil { - root = pkg - } - } - } else { - var err error - root, err = 
loader.load(rootPkgPath) - if err != nil { - return nil, err - } - } - all := make([]*packages.Package, 0, len(loader.pkgs)) - for _, pkg := range loader.pkgs { - all = append(all, pkg) - } - sort.Slice(all, func(i, j int) bool { return all[i].PkgPath < all[j].PkgPath }) - return &localFastPathLoaded{ - fset: loader.fset, - root: root, - allPackages: all, - byPath: loader.pkgs, - fingerprints: loader.meta, - loader: loader, - }, nil -} - -func (l *localFastPathLoader) load(pkgPath string) (*packages.Package, error) { - if pkg := l.pkgs[pkgPath]; pkg != nil { - return pkg, nil - } - fp := l.meta[pkgPath] - if fp == nil { - return nil, fmt.Errorf("package %s not tracked as local", pkgPath) - } - files := filesFromMeta(fp.Files) - if len(files) == 0 { - return nil, fmt.Errorf("package %s has no files", pkgPath) - } - mode := parser.SkipObjectResolution - if pkgPath == l.rootPkgPath { - mode |= parser.ParseComments - } - syntax := make([]*ast.File, 0, len(files)) - parseStart := time.Now() - for _, name := range files { - file, err := l.parseFileForFastPath(name, mode, pkgPath) - if err != nil { - return nil, err - } - syntax = append(syntax, file) - } - logTiming(l.ctx, "incremental.local_fastpath.parse", parseStart) - if len(syntax) == 0 { - return nil, fmt.Errorf("package %s parsed no files", pkgPath) - } - - pkgName := syntax[0].Name.Name - info := newFastPathTypesInfo(pkgPath == l.rootPkgPath) - pkg := &packages.Package{ - Fset: l.fset, - Name: pkgName, - PkgPath: pkgPath, - GoFiles: append([]string(nil), files...), - CompiledGoFiles: append([]string(nil), files...), - Syntax: syntax, - TypesInfo: info, - Imports: make(map[string]*packages.Package), - } - l.pkgs[pkgPath] = pkg - - conf := &types.Config{ - Importer: importerFunc(func(path string) (*types.Package, error) { - return l.importPackage(path) - }), - IgnoreFuncBodies: l.shouldIgnoreFuncBodies(pkgPath), - Sizes: types.SizesFor("gc", runtime.GOARCH), - Error: func(err error) { - pkg.Errors = 
append(pkg.Errors, packages.Error{Msg: err.Error()}) - }, - } - typecheckStart := time.Now() - checkedPkg, err := conf.Check(pkgPath, l.fset, syntax, info) - logTiming(l.ctx, "incremental.local_fastpath.typecheck", typecheckStart) - if checkedPkg != nil { - pkg.Types = checkedPkg - l.imported[pkgPath] = checkedPkg - } - if l.shouldRetryWithoutBodyStripping(pkgPath, pkg.Errors) { - return l.reloadWithoutBodyStripping(pkgPath, files, mode, pkg) - } - if err != nil && len(pkg.Errors) == 0 { - return nil, err - } - if len(pkg.Errors) > 0 { - return nil, formatLocalTypeCheckError(l.wd, pkgPath, pkg.Errors) - } - - imports := packageImportPaths(syntax) - localImports := make([]string, 0, len(imports)) - for _, path := range imports { - if dep := l.pkgs[path]; dep != nil { - pkg.Imports[path] = dep - localImports = append(localImports, path) - } - } - sort.Strings(localImports) - updated := *fp - updated.LocalImports = localImports - updated.Tags = l.tags - updated.WD = filepath.Clean(l.wd) - l.meta[pkgPath] = &updated - return pkg, nil -} - -func (l *localFastPathLoader) parseFileForFastPath(name string, mode parser.Mode, pkgPath string) (*ast.File, error) { - file, err := parser.ParseFile(l.fset, name, nil, mode) - if err != nil { - return nil, err - } - if l.shouldStripFunctionBodies(pkgPath) { - stripFunctionBodies(file) - pruneImportsWithoutTopLevelUse(file) - } - return file, nil -} - -func (l *localFastPathLoader) reloadWithoutBodyStripping(pkgPath string, files []string, mode parser.Mode, pkg *packages.Package) (*packages.Package, error) { - syntax := make([]*ast.File, 0, len(files)) - parseStart := time.Now() - for _, name := range files { - file, err := parser.ParseFile(l.fset, name, nil, mode) - if err != nil { - return nil, err - } - syntax = append(syntax, file) - } - logTiming(l.ctx, "incremental.local_fastpath.parse_retry", parseStart) - pkg.Syntax = syntax - pkg.Errors = nil - pkg.TypesInfo = newFastPathTypesInfo(pkgPath == l.rootPkgPath) - conf := 
&types.Config{ - Importer: importerFunc(func(path string) (*types.Package, error) { - return l.importPackage(path) - }), - IgnoreFuncBodies: false, - Sizes: types.SizesFor("gc", runtime.GOARCH), - Error: func(err error) { - pkg.Errors = append(pkg.Errors, packages.Error{Msg: err.Error()}) - }, - } - typecheckStart := time.Now() - checkedPkg, err := conf.Check(pkgPath, l.fset, syntax, pkg.TypesInfo) - logTiming(l.ctx, "incremental.local_fastpath.typecheck_retry", typecheckStart) - if checkedPkg != nil { - pkg.Types = checkedPkg - l.imported[pkgPath] = checkedPkg - } - if err != nil && len(pkg.Errors) == 0 { - return nil, err - } - if len(pkg.Errors) > 0 { - return nil, formatLocalTypeCheckError(l.wd, pkgPath, pkg.Errors) - } - return pkg, nil -} - -func (l *localFastPathLoader) shouldRetryWithoutBodyStripping(pkgPath string, errs []packages.Error) bool { - if !l.shouldStripFunctionBodies(pkgPath) || len(errs) == 0 { - return false - } - for _, pkgErr := range errs { - msg := pkgErr.Msg - if strings.Contains(msg, "missing function body") || strings.Contains(msg, "func init must have a body") { - return true - } - } - return false -} - -func (l *localFastPathLoader) importPackage(path string) (*types.Package, error) { - if l.shouldImportFromExport(path) { - pkg, err := l.importExportPackage(path) - if err == nil { - return pkg, nil - } - // Cached local export artifacts are an optimization only. If one is - // missing or corrupted, fall back to source loading for correctness. 
- if _, ok := l.localExports[path]; ok && l.meta[path] != nil { - delete(l.localExports, path) - pkg, loadErr := l.load(path) - if loadErr == nil { - l.refreshLocalExport(path, pkg) - return pkg.Types, nil - } - return nil, loadErr - } - return nil, err - } - if l.meta[path] != nil { - pkg, err := l.load(path) - if err != nil { - return nil, err - } - l.refreshLocalExport(path, pkg) - return pkg.Types, nil - } - if l.externalImp == nil { - return nil, fmt.Errorf("missing external importer") - } - return l.importExportPackage(path) -} - -func (l *localFastPathLoader) openExternalExport(path string) (io.ReadCloser, error) { - meta, ok := l.externalMeta[path] - if !ok || meta.ExportFile == "" { - if l.meta[path] != nil || l.isLikelyLocalImport(path) { - return nil, fmt.Errorf("missing local export data for %s", path) - } - return nil, fmt.Errorf("missing external export data for %s", path) - } - return os.Open(meta.ExportFile) -} - -func (l *localFastPathLoader) isLikelyLocalImport(path string) bool { - if l == nil || l.modulePrefix == "" { - return false - } - return path == l.modulePrefix || strings.HasPrefix(path, l.modulePrefix+"/") -} - -func moduleImportPrefix(meta map[string]*packageFingerprint) string { - if len(meta) == 0 { - return "" - } - paths := make([]string, 0, len(meta)) - for path := range meta { - paths = append(paths, path) - } - sort.Strings(paths) - prefix := strings.Split(paths[0], "/") - for _, path := range paths[1:] { - parts := strings.Split(path, "/") - n := len(prefix) - if len(parts) < n { - n = len(parts) - } - i := 0 - for i < n && prefix[i] == parts[i] { - i++ - } - prefix = prefix[:i] - if len(prefix) == 0 { - return "" - } - } - return strings.Join(prefix, "/") -} - -func (l *localFastPathLoader) importExportPackage(path string) (*types.Package, error) { - if l == nil { - return nil, fmt.Errorf("missing local fast path loader") - } - if pkg := l.imported[path]; pkg != nil && pkg.Complete() { - return pkg, nil - } - if exportPath := 
l.localExports[path]; exportPath != "" { - f, err := os.Open(exportPath) - if err != nil { - return nil, err - } - defer f.Close() - pkg, err := gcexportdata.Read(f, l.fset, l.imported, path) - if err != nil { - return nil, err - } - l.imported[path] = pkg - return pkg, nil - } - if l.externalImp == nil { - return nil, fmt.Errorf("missing external importer") - } - pkg, err := l.externalImp.Import(path) - if err != nil { - if l.externalFallback != nil && strings.Contains(err.Error(), "missing external export data for ") { - pkg, fallbackErr := l.externalFallback.Import(path) - if fallbackErr == nil { - l.imported[path] = pkg - return pkg, nil - } - } - return nil, err - } - l.imported[path] = pkg - return pkg, nil -} - -func (l *localFastPathLoader) shouldImportFromExport(pkgPath string) bool { - if l == nil { - return false - } - if _, source := l.sourcePkgs[pkgPath]; source { - return false - } - if _, ok := l.localExports[pkgPath]; ok { - return true - } - _, ok := l.externalMeta[pkgPath] - return ok -} - -func (l *localFastPathLoader) refreshLocalExport(pkgPath string, pkg *packages.Package) { - if l == nil || pkg == nil || pkg.Fset == nil || pkg.Types == nil { - return - } - fp := l.meta[pkgPath] - exportPath := localExportPathForFingerprint(l.wd, l.tags, fp) - if exportPath == "" { - return - } - writeLocalPackageExportFile(exportPath, pkg.Fset, pkg.Types) - l.localExports[pkgPath] = exportPath -} - -func (l *localFastPathLoader) markSourceClosure() { - if l == nil { - return - } - reverse := make(map[string][]string) - for pkgPath, fp := range l.meta { - if fp == nil { - continue - } - for _, imp := range fp.LocalImports { - reverse[imp] = append(reverse[imp], pkgPath) - } - } - queue := make([]string, 0, len(l.changedPkgs)+1) - queue = append(queue, l.rootPkgPath) - for pkgPath := range l.changedPkgs { - queue = append(queue, pkgPath) - } - for len(queue) > 0 { - pkgPath := queue[0] - queue = queue[1:] - if _, seen := l.sourcePkgs[pkgPath]; seen { - continue 
- } - l.sourcePkgs[pkgPath] = struct{}{} - for _, importer := range reverse[pkgPath] { - if _, seen := l.sourcePkgs[importer]; !seen { - queue = append(queue, importer) - } - } - } -} - -func (l *localFastPathLoader) shouldStripFunctionBodies(pkgPath string) bool { - if l == nil { - return false - } - if pkgPath == l.rootPkgPath { - return false - } - _, changed := l.changedPkgs[pkgPath] - return !changed -} - -func (l *localFastPathLoader) shouldIgnoreFuncBodies(pkgPath string) bool { - return l.shouldStripFunctionBodies(pkgPath) -} - -type importerFunc func(string) (*types.Package, error) - -func (fn importerFunc) Import(path string) (*types.Package, error) { - return fn(path) -} - -func packageImportPaths(files []*ast.File) []string { - seen := make(map[string]struct{}) - var out []string - for _, file := range files { - for _, spec := range file.Imports { - path := strings.Trim(spec.Path.Value, "\"") - if path == "" { - continue - } - if _, ok := seen[path]; ok { - continue - } - seen[path] = struct{}{} - out = append(out, path) - } - } - sort.Strings(out) - return out -} - -func newFastPathTypesInfo(full bool) *types.Info { - info := &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Defs: make(map[*ast.Ident]types.Object), - Uses: make(map[*ast.Ident]types.Object), - } - if !full { - return info - } - info.Implicits = make(map[ast.Node]types.Object) - info.Selections = make(map[*ast.SelectorExpr]*types.Selection) - info.Scopes = make(map[ast.Node]*types.Scope) - info.Instances = make(map[*ast.Ident]types.Instance) - return info -} - -func pruneImportsWithoutTopLevelUse(file *ast.File) { - if file == nil || len(file.Imports) == 0 { - return - } - used := usedImportNames(file) - filtered := file.Imports[:0] - for _, spec := range file.Imports { - if spec == nil || spec.Path == nil { - continue - } - name := importName(spec) - if name == "_" || name == "." 
{ - filtered = append(filtered, spec) - continue - } - if _, ok := used[name]; ok { - filtered = append(filtered, spec) - } - } - file.Imports = filtered - for _, decl := range file.Decls { - gen, ok := decl.(*ast.GenDecl) - if !ok || gen.Tok != token.IMPORT { - continue - } - specs := gen.Specs[:0] - for _, spec := range gen.Specs { - importSpec, ok := spec.(*ast.ImportSpec) - if !ok || importSpec.Path == nil { - continue - } - name := importName(importSpec) - if name == "_" || name == "." { - specs = append(specs, spec) - continue - } - if _, ok := used[name]; ok { - specs = append(specs, spec) - } - } - gen.Specs = specs - } -} - -func usedImportNames(file *ast.File) map[string]struct{} { - used := make(map[string]struct{}) - ast.Inspect(file, func(node ast.Node) bool { - sel, ok := node.(*ast.SelectorExpr) - if !ok { - return true - } - ident, ok := sel.X.(*ast.Ident) - if !ok || ident.Name == "" { - return true - } - used[ident.Name] = struct{}{} - return true - }) - return used -} - -func importName(spec *ast.ImportSpec) string { - if spec == nil || spec.Path == nil { - return "" - } - if spec.Name != nil && spec.Name.Name != "" { - return spec.Name.Name - } - path := strings.Trim(spec.Path.Value, "\"") - if path == "" { - return "" - } - if slash := strings.LastIndex(path, "/"); slash >= 0 { - path = path[slash+1:] - } - return path -} - -func generateFromTypedPackages(ctx context.Context, loaded *localFastPathLoaded, opts *GenerateOptions) ([]GenerateResult, []error) { - if loaded == nil { - return nil, []error{fmt.Errorf("missing loaded packages")} - } - root := loaded.root - if root == nil { - return nil, []error{fmt.Errorf("missing root package")} - } - if opts == nil { - opts = &GenerateOptions{} - } - pkgStart := time.Now() - res := GenerateResult{PkgPath: root.PkgPath} - outDir, err := detectOutputDir(root.GoFiles) - logTiming(ctx, "generate.package."+root.PkgPath+".output_dir", pkgStart) - if err != nil { - res.Errs = append(res.Errs, err) - return 
[]GenerateResult{res}, nil - } - res.OutputPath = filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") - - var summary *summaryProviderResolver - if loaded.loader != nil { - summary = newSummaryProviderResolver(ctx, loaded.loader.summaries, loaded.loader.importExportPackage) - } - oc := newObjectCacheWithLoader(loaded.allPackages, nil, nil, summary) - g := newGen(root) - injectorStart := time.Now() - injectorFiles, errs := generateInjectors(oc, g, root) - logTiming(ctx, "generate.package."+root.PkgPath+".injectors", injectorStart) - if len(errs) > 0 { - res.Errs = errs - return []GenerateResult{res}, nil - } - copyStart := time.Now() - copyNonInjectorDecls(g, injectorFiles, root.TypesInfo) - logTiming(ctx, "generate.package."+root.PkgPath+".copy_non_injectors", copyStart) - frameStart := time.Now() - goSrc := g.frame(opts.Tags) - logTiming(ctx, "generate.package."+root.PkgPath+".frame", frameStart) - if len(opts.Header) > 0 { - goSrc = append(opts.Header, goSrc...) - } - formatStart := time.Now() - fmtSrc, err := format.Source(goSrc) - logTiming(ctx, "generate.package."+root.PkgPath+".format", formatStart) - if err != nil { - res.Errs = append(res.Errs, err) - } else { - goSrc = fmtSrc - } - res.Content = goSrc - logTiming(ctx, "generate.package."+root.PkgPath+".total", pkgStart) - return []GenerateResult{res}, nil -} - -func writeIncrementalFingerprints(snapshot *incrementalFingerprintSnapshot, wd string, tags string) { - if snapshot == nil { - return - } - for _, path := range snapshot.changed { - fp := snapshot.fingerprints[path] - if fp == nil { - continue - } - writeIncrementalFingerprint(incrementalFingerprintKey(wd, tags, fp.PkgPath), fp) - } -} - -func writeIncrementalManifestFromState(wd string, env []string, patterns []string, opts *GenerateOptions, state *incrementalPreloadState, snapshot *incrementalFingerprintSnapshot, generated []GenerateResult) { - if snapshot == nil || len(generated) == 0 || state == nil || state.manifest == nil { - return - } 
- scope := runCacheScope(wd, patterns) - manifest := &incrementalManifest{ - Version: incrementalManifestVersion, - WD: scope, - Tags: opts.Tags, - Prefix: opts.PrefixOutputFile, - HeaderHash: headerHash(opts.Header), - EnvHash: envHash(env), - Patterns: normalizePatternsForScope(wd, packageCacheScope(wd), patterns), - LocalPackages: snapshotPackageFingerprints(snapshot), - ExternalPkgs: append([]externalPackageExport(nil), state.manifest.ExternalPkgs...), - ExternalFiles: append([]cacheFile(nil), state.manifest.ExternalFiles...), - ExtraFiles: extraCacheFiles(wd), - } - for _, out := range generated { - if len(out.Content) == 0 || out.OutputPath == "" { - continue - } - contentKey := incrementalContentKey(out.Content) - writeCache(contentKey, out.Content) - manifest.Outputs = append(manifest.Outputs, incrementalOutput{ - PkgPath: out.PkgPath, - OutputPath: out.OutputPath, - ContentKey: contentKey, - }) - } - if len(manifest.Outputs) == 0 { - return - } - selectorKey := incrementalManifestSelectorKey(wd, env, patterns, opts) - writeIncrementalManifestFile(selectorKey, manifest) - writeIncrementalManifestFile(incrementalManifestStateKey(selectorKey, manifest.LocalPackages), manifest) -} - -func writeIncrementalGraphFromSnapshot(wd string, tags string, roots []string, fps map[string]*packageFingerprint) { - if len(roots) == 0 || len(fps) == 0 { - return - } - graph := &incrementalGraph{ - Version: incrementalGraphVersion, - WD: packageCacheScope(wd), - Tags: tags, - Roots: append([]string(nil), roots...), - LocalReverse: make(map[string][]string), - } - sort.Strings(graph.Roots) - for _, fp := range fps { - if fp == nil { - continue - } - for _, imp := range fp.LocalImports { - graph.LocalReverse[imp] = append(graph.LocalReverse[imp], fp.PkgPath) - } - } - for path := range graph.LocalReverse { - sort.Strings(graph.LocalReverse[path]) - } - writeIncrementalGraph(incrementalGraphKey(wd, tags, graph.Roots), graph) -} - -func manifestOutputPkgPaths(manifest 
*incrementalManifest) []string { - if manifest == nil || len(manifest.Outputs) == 0 { - return nil - } - seen := make(map[string]struct{}, len(manifest.Outputs)) - paths := make([]string, 0, len(manifest.Outputs)) - for _, out := range manifest.Outputs { - if out.PkgPath == "" { - continue - } - if _, ok := seen[out.PkgPath]; ok { - continue - } - seen[out.PkgPath] = struct{}{} - paths = append(paths, out.PkgPath) - } - sort.Strings(paths) - return paths -} - -func changedPackagePaths(previous []packageFingerprint, current []packageFingerprint) []string { - if len(current) == 0 { - return nil - } - prevByPath := make(map[string]packageFingerprint, len(previous)) - for _, fp := range previous { - prevByPath[fp.PkgPath] = fp - } - changed := make([]string, 0, len(current)) - for _, fp := range current { - prev, ok := prevByPath[fp.PkgPath] - if !ok || !incrementalFingerprintEquivalent(&prev, &fp) { - changed = append(changed, fp.PkgPath) - } - } - sort.Strings(changed) - return changed -} diff --git a/internal/wire/parse.go b/internal/wire/parse.go index a7a1a02..e6f8cb1 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -19,6 +19,7 @@ import ( "errors" "fmt" "go/ast" + "go/parser" "go/token" "go/types" "os" @@ -30,6 +31,8 @@ import ( "golang.org/x/tools/go/ast/astutil" "golang.org/x/tools/go/packages" "golang.org/x/tools/go/types/typeutil" + + "github.com/goforj/wire/internal/loader" ) // A providerSetSrc captures the source for a type provided by a ProviderSet. @@ -250,11 +253,8 @@ type Field struct { // In case of duplicate environment variables, the last one in the list // takes precedence. 
func Load(ctx context.Context, wd string, env []string, tags string, patterns []string) (*Info, []error) { - if IncrementalEnabled(ctx, env) { - debugf(ctx, "incremental=enabled") - } loadStart := time.Now() - pkgs, loader, errs := load(ctx, wd, env, tags, patterns) + pkgs, errs := load(ctx, wd, env, tags, patterns) logTiming(ctx, "load.packages", loadStart) if len(errs) > 0 { return nil, errs @@ -267,19 +267,13 @@ func Load(ctx context.Context, wd string, env []string, tags string, patterns [] Fset: fset, Sets: make(map[ProviderSetID]*ProviderSet), } - oc := newObjectCache(pkgs, loader) + oc := newObjectCache(pkgs) ec := new(errorCollector) for _, pkg := range pkgs { if isWireImport(pkg.PkgPath) { // The marker function package confuses analysis. continue } - if loaded, errs := oc.ensurePackage(pkg.PkgPath); len(errs) > 0 { - ec.add(errs...) - continue - } else if loaded != nil { - pkg = loaded - } pkgStart := time.Now() scope := pkg.Types.Scope() setStart := time.Now() @@ -367,68 +361,48 @@ func Load(ctx context.Context, wd string, env []string, tags string, patterns [] // env is nil or empty, it is interpreted as an empty set of variables. // In case of duplicate environment variables, the last one in the list // takes precedence. 
-func load(ctx context.Context, wd string, env []string, tags string, patterns []string) ([]*packages.Package, *lazyLoader, []error) { - var session *incrementalSession +func load(ctx context.Context, wd string, env []string, tags string, patterns []string) ([]*packages.Package, []error) { fset := token.NewFileSet() - if IncrementalEnabled(ctx, env) { - session = getIncrementalSession(wd, env, tags) - fset = session.fset - debugf(ctx, "incremental session=enabled") - } - baseCfg := &packages.Config{ - Context: ctx, - Mode: baseLoadMode(ctx), - Dir: wd, + loaderMode := effectiveLoaderMode(ctx, wd, env) + parseStats := &parseFileStats{} + loadStart := time.Now() + result, err := loader.New().LoadPackages(withLoaderTiming(ctx), loader.PackageLoadRequest{ + WD: wd, Env: env, - BuildFlags: []string{"-tags=wireinject"}, + Tags: tags, + Patterns: append([]string(nil), patterns...), + Mode: packages.LoadAllSyntax, + LoaderMode: loaderMode, Fset: fset, + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + start := time.Now() + file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + parseStats.record(false, time.Since(start), err, false) + return file, err + }, + }) + logTiming(ctx, "load.packages.load", loadStart) + var typedPkgs []*packages.Package + if result != nil { + typedPkgs = result.Packages + debugf(ctx, "load.packages.backend=%s", result.Backend) + if result.FallbackReason != loader.FallbackReasonNone { + debugf(ctx, "load.packages.fallback_reason=%s", result.FallbackReason) + if result.FallbackDetail != "" { + debugf(ctx, "load.packages.fallback_detail=%s", result.FallbackDetail) + } + } } - if len(tags) > 0 { - baseCfg.BuildFlags[0] += " " + tags - } - escaped := make([]string, len(patterns)) - for i := range patterns { - escaped[i] = "pattern=" + patterns[i] - } - baseLoadStart := time.Now() - pkgs, err := packages.Load(baseCfg, escaped...) 
- logTiming(ctx, "load.packages.base.load", baseLoadStart) - logLoadDebug(ctx, "base", baseCfg.Mode, strings.Join(patterns, ","), wd, pkgs, nil) + logLoadDebug(ctx, "typed", packages.LoadAllSyntax, strings.Join(patterns, ","), wd, typedPkgs, parseStats) if err != nil { - return nil, nil, []error{err} + return nil, []error{err} } - baseErrsStart := time.Now() - errs := collectLoadErrors(pkgs) - logTiming(ctx, "load.packages.base.collect_errors", baseErrsStart) + errs := collectLoadErrors(typedPkgs) + logTiming(ctx, "load.packages.collect_errors", loadStart) if len(errs) > 0 { - return nil, nil, errs - } - var fingerprints *incrementalFingerprintSnapshot - if !incrementalColdBootstrapEnabled(ctx) { - fingerprints = analyzeIncrementalFingerprints(ctx, wd, env, tags, pkgs) - analyzeIncrementalGraph(ctx, wd, env, tags, pkgs, fingerprints) - } - - baseFiles := collectPackageFiles(pkgs) - loader := &lazyLoader{ - ctx: ctx, - wd: wd, - env: env, - tags: tags, - fset: fset, - baseFiles: baseFiles, - session: session, - fingerprints: fingerprints, - } - return pkgs, loader, nil -} - -func baseLoadMode(ctx context.Context) packages.LoadMode { - mode := packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports - if !incrementalColdBootstrapEnabled(ctx) { - mode |= packages.NeedDeps + return nil, errs } - return mode + return typedPkgs, nil } func collectLoadErrors(pkgs []*packages.Package) []error { @@ -481,8 +455,6 @@ type objectCache struct { packages map[string]*packages.Package objects map[objRef]objCacheEntry hasher typeutil.Hasher - loader *lazyLoader - summary *summaryProviderResolver } type objRef struct { @@ -495,11 +467,7 @@ type objCacheEntry struct { errs []error } -func newObjectCache(pkgs []*packages.Package, loader *lazyLoader) *objectCache { - return newObjectCacheWithLoader(pkgs, loader, nil, nil) -} - -func newObjectCacheWithLoader(pkgs []*packages.Package, loader *lazyLoader, _ *localFastPathLoader, summary 
*summaryProviderResolver) *objectCache { +func newObjectCache(pkgs []*packages.Package) *objectCache { if len(pkgs) == 0 { panic("object cache must have packages to draw from") } @@ -508,11 +476,6 @@ func newObjectCacheWithLoader(pkgs []*packages.Package, loader *lazyLoader, _ *l packages: make(map[string]*packages.Package), objects: make(map[objRef]objCacheEntry), hasher: typeutil.MakeHasher(), - loader: loader, - summary: summary, - } - if oc.fset == nil && loader != nil { - oc.fset = loader.fset } // Depth-first search of all dependencies to gather import path to // packages.Package mapping. go/packages guarantees that for a single @@ -546,24 +509,6 @@ func (oc *objectCache) registerPackages(pkgs []*packages.Package, replace bool) } } -func (oc *objectCache) ensurePackage(pkgPath string) (*packages.Package, []error) { - if pkg := oc.packages[pkgPath]; pkg != nil && pkg.TypesInfo != nil && len(pkg.Syntax) > 0 { - return pkg, nil - } - if oc.loader == nil { - if pkg := oc.packages[pkgPath]; pkg != nil { - return pkg, nil - } - return nil, []error{fmt.Errorf("package %q is missing type information", pkgPath)} - } - loaded, errs := oc.loader.load(pkgPath) - if len(errs) > 0 { - return nil, errs - } - oc.registerPackages(loaded, true) - return oc.packages[pkgPath], nil -} - // get converts a Go object into a Wire structure. It may return a *Provider, an // *IfaceBinding, a *ProviderSet, a *Value, or a []*Field. 
func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { @@ -582,14 +527,6 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { }() switch obj := obj.(type) { case *types.Var: - if isProviderSetType(obj.Type()) && oc.summary != nil { - if pset, ok, summaryErrs := oc.summary.Resolve(obj.Pkg().Path(), obj.Name()); ok { - return pset, summaryErrs - } - } - if _, errs := oc.ensurePackage(ref.importPath); len(errs) > 0 { - return nil, errs - } spec := oc.varDecl(obj) if spec == nil || len(spec.Values) == 0 { return nil, []error{fmt.Errorf("%v is not a provider or a provider set", obj)} @@ -605,9 +542,6 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { case *types.Func: return processFuncProvider(oc.fset, obj) default: - if _, errs := oc.ensurePackage(ref.importPath); len(errs) > 0 { - return nil, errs - } return nil, []error{fmt.Errorf("%v is not a provider or a provider set", obj)} } } diff --git a/internal/wire/parse_coverage_test.go b/internal/wire/parse_coverage_test.go index 516d1d5..7c7a3b7 100644 --- a/internal/wire/parse_coverage_test.go +++ b/internal/wire/parse_coverage_test.go @@ -333,18 +333,18 @@ func TestAllFields(t *testing.T) { } } -func TestObjectCacheEnsurePackage(t *testing.T) { +func TestNewObjectCacheRegistersPackages(t *testing.T) { t.Parallel() fset := token.NewFileSet() pkg := &packages.Package{PkgPath: "example.com/p", Fset: fset} - oc := newObjectCache([]*packages.Package{pkg}, nil) + oc := newObjectCache([]*packages.Package{pkg}) - if got, errs := oc.ensurePackage(pkg.PkgPath); len(errs) != 0 || got != pkg { - t.Fatalf("expected existing package without errors, got pkg=%v errs=%v", got, errs) + if got := oc.packages[pkg.PkgPath]; got != pkg { + t.Fatalf("expected package to be registered, got %v", got) } - if _, errs := oc.ensurePackage("missing.example.com"); len(errs) == 0 { - t.Fatal("expected missing package error") + if got := 
oc.packages["missing.example.com"]; got != nil { + t.Fatalf("expected missing package to remain absent, got %v", got) } } diff --git a/internal/wire/parser_lazy_loader.go b/internal/wire/parser_lazy_loader.go deleted file mode 100644 index f6137bc..0000000 --- a/internal/wire/parser_lazy_loader.go +++ /dev/null @@ -1,188 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "go/ast" - "go/parser" - "go/token" - "path/filepath" - "time" - - "golang.org/x/tools/go/packages" -) - -type lazyLoader struct { - ctx context.Context - wd string - env []string - tags string - fset *token.FileSet - baseFiles map[string]map[string]struct{} - session *incrementalSession - fingerprints *incrementalFingerprintSnapshot - loaded map[string]*packages.Package -} - -func collectPackageFiles(pkgs []*packages.Package) map[string]map[string]struct{} { - all := collectAllPackages(pkgs) - out := make(map[string]map[string]struct{}, len(all)) - for path, pkg := range all { - if pkg == nil { - continue - } - files := make(map[string]struct{}, len(pkg.CompiledGoFiles)) - for _, name := range pkg.CompiledGoFiles { - files[filepath.Clean(name)] = struct{}{} - } - if len(files) > 0 { - out[path] = files - } - } - return out -} - -func collectAllPackages(pkgs []*packages.Package) map[string]*packages.Package { - all := make(map[string]*packages.Package) - stack := append([]*packages.Package(nil), pkgs...) 
- for len(stack) > 0 { - p := stack[len(stack)-1] - stack = stack[:len(stack)-1] - if p == nil || all[p.PkgPath] != nil { - continue - } - all[p.PkgPath] = p - for _, imp := range p.Imports { - stack = append(stack, imp) - } - } - return all -} - -func (ll *lazyLoader) load(pkgPath string) ([]*packages.Package, []error) { - return ll.loadWithMode(pkgPath, ll.fullMode(), "load.packages.lazy.load") -} - -func (ll *lazyLoader) fullMode() packages.LoadMode { - return packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile -} - -func (ll *lazyLoader) loadWithMode(pkgPath string, mode packages.LoadMode, timingLabel string) ([]*packages.Package, []error) { - parseStats := &parseFileStats{} - cfg := &packages.Config{ - Context: ll.ctx, - Mode: mode, - Dir: ll.wd, - Env: ll.env, - BuildFlags: []string{"-tags=wireinject"}, - Fset: ll.fset, - ParseFile: ll.parseFileFor(pkgPath, parseStats), - } - if len(ll.tags) > 0 { - cfg.BuildFlags[0] += " " + ll.tags - } - loadStart := time.Now() - pkgs, err := packages.Load(cfg, "pattern="+pkgPath) - logTiming(ll.ctx, timingLabel, loadStart) - logLoadDebug(ll.ctx, "lazy", mode, pkgPath, ll.wd, pkgs, parseStats) - if err != nil { - return nil, []error{err} - } - errs := collectLoadErrors(pkgs) - if len(errs) > 0 { - return nil, errs - } - ll.rememberPackages(pkgs) - return pkgs, nil -} - -func (ll *lazyLoader) rememberPackages(pkgs []*packages.Package) { - if ll == nil || len(pkgs) == 0 { - return - } - if ll.loaded == nil { - ll.loaded = make(map[string]*packages.Package) - } - for path, pkg := range collectAllPackages(pkgs) { - if pkg != nil { - ll.loaded[path] = pkg - } - } -} - -func (ll *lazyLoader) parseFileFor(pkgPath string, stats *parseFileStats) func(*token.FileSet, string, []byte) (*ast.File, error) { - primary := primaryFileSet(ll.baseFiles[pkgPath]) - return func(fset 
*token.FileSet, filename string, src []byte) (*ast.File, error) { - start := time.Now() - isPrimary := isPrimaryFile(primary, filename) - keepBodies := ll.shouldKeepDependencyBodies(filename) - if !isPrimary && !keepBodies && ll.session != nil { - if file, ok := ll.session.getParsedDep(filename, src); ok { - if stats != nil { - stats.record(false, time.Since(start), nil, true) - } - return file, nil - } - } - mode := parser.SkipObjectResolution - if isPrimary { - mode = parser.ParseComments | parser.SkipObjectResolution - } - file, err := parser.ParseFile(fset, filename, src, mode) - if stats != nil { - stats.record(isPrimary, time.Since(start), err, false) - } - if err != nil { - return nil, err - } - if primary == nil { - return file, nil - } - if isPrimary { - return file, nil - } - if keepBodies { - return file, nil - } - for _, decl := range file.Decls { - if fn, ok := decl.(*ast.FuncDecl); ok { - fn.Body = nil - fn.Doc = nil - } - } - if ll.session != nil { - ll.session.storeParsedDep(filename, src, file) - } - return file, nil - } -} - -func (ll *lazyLoader) shouldKeepDependencyBodies(filename string) bool { - if ll == nil || ll.fingerprints == nil || len(ll.fingerprints.touched) == 0 { - return false - } - clean := filepath.Clean(filename) - for _, pkgPath := range ll.fingerprints.touched { - files := ll.baseFiles[pkgPath] - if len(files) == 0 { - continue - } - if _, ok := files[clean]; ok { - return true - } - } - return false -} diff --git a/internal/wire/parser_lazy_loader_test.go b/internal/wire/parser_lazy_loader_test.go deleted file mode 100644 index 86b49da..0000000 --- a/internal/wire/parser_lazy_loader_test.go +++ /dev/null @@ -1,204 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package wire - -import ( - "context" - "go/ast" - "go/token" - "os" - "path/filepath" - "strings" - "testing" -) - -func TestLazyLoaderParseFileFor(t *testing.T) { - t.Helper() - fset := token.NewFileSet() - pkgPath := "example.com/pkg" - root := t.TempDir() - primary := filepath.Join(root, "primary.go") - secondary := filepath.Join(root, "secondary.go") - ll := &lazyLoader{ - fset: fset, - baseFiles: map[string]map[string]struct{}{ - pkgPath: {filepath.Clean(primary): {}}, - }, - } - src := strings.Join([]string{ - "package pkg", - "", - "// Doc comment", - "func Foo() {", - "\tprintln(\"hi\")", - "}", - "", - }, "\n") - - parse := ll.parseFileFor(pkgPath, &parseFileStats{}) - file, err := parse(fset, primary, []byte(src)) - if err != nil { - t.Fatalf("parse primary: %v", err) - } - fn := firstFuncDecl(t, file) - if fn.Body == nil { - t.Fatal("expected primary file to keep function body") - } - if fn.Doc == nil { - t.Fatal("expected primary file to keep doc comment") - } - - file, err = parse(fset, secondary, []byte(src)) - if err != nil { - t.Fatalf("parse secondary: %v", err) - } - fn = firstFuncDecl(t, file) - if fn.Body != nil { - t.Fatal("expected secondary file to strip function body") - } - if fn.Doc != nil { - t.Fatal("expected secondary file to strip doc comment") - } -} - -func TestLazyLoaderParseFileForCachesDependencyFiles(t *testing.T) { - t.Helper() - fset := token.NewFileSet() - pkgPath := "example.com/pkg" - root := t.TempDir() - primary := filepath.Join(root, "primary.go") - secondary := filepath.Join(root, "secondary.go") - 
session := &incrementalSession{ - fset: fset, - parsedDeps: make(map[string]cachedParsedFile), - } - ll := &lazyLoader{ - fset: fset, - baseFiles: map[string]map[string]struct{}{ - pkgPath: {filepath.Clean(primary): {}}, - }, - session: session, - } - src := []byte(strings.Join([]string{ - "package pkg", - "", - "func Foo() {", - "\tprintln(\"hi\")", - "}", - "", - }, "\n")) - - stats1 := &parseFileStats{} - parse1 := ll.parseFileFor(pkgPath, stats1) - file1, err := parse1(fset, secondary, src) - if err != nil { - t.Fatalf("first parse: %v", err) - } - snap1 := stats1.snapshot() - if snap1.cacheHits != 0 || snap1.cacheMisses != 1 { - t.Fatalf("first parse stats = %+v, want 0 hits and 1 miss", snap1) - } - - stats2 := &parseFileStats{} - parse2 := ll.parseFileFor(pkgPath, stats2) - file2, err := parse2(fset, secondary, src) - if err != nil { - t.Fatalf("second parse: %v", err) - } - if file1 != file2 { - t.Fatal("expected cached dependency parse to reuse AST") - } - snap2 := stats2.snapshot() - if snap2.cacheHits != 1 || snap2.cacheMisses != 0 { - t.Fatalf("second parse stats = %+v, want 1 hit and 0 misses", snap2) - } -} - -func TestLoadModuleUsesWireinjectTagsForDeps(t *testing.T) { - repoRoot := mustRepoRoot(t) - root := t.TempDir() - - writeFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ - "module example.com/app", - "", - "go 1.19", - "", - "require github.com/goforj/wire v0.0.0", - "replace github.com/goforj/wire => " + repoRoot, - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package app", - "", - "import (", - "\t\"example.com/app/dep\"", - "\t\"github.com/goforj/wire\"", - ")", - "", - "func Init() *dep.Foo {", - "\twire.Build(dep.New)", - "\treturn nil", - "}", - "", - }, "\n")) - - writeFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ - "package dep", - "", - "type Foo struct{}", - "", - }, "\n")) - - 
writeFile(t, filepath.Join(root, "dep", "dep_inject.go"), strings.Join([]string{ - "//go:build wireinject", - "// +build wireinject", - "", - "package dep", - "", - "func New() *Foo {", - "\treturn &Foo{}", - "}", - "", - }, "\n")) - - env := append(os.Environ(), "GOWORK=off") - ctx := context.Background() - - info, errs := Load(ctx, root, env, "", []string{"./app"}) - if len(errs) > 0 { - t.Fatalf("Load returned errors: %v", errs) - } - if info == nil { - t.Fatal("Load returned nil info") - } - if len(info.Injectors) != 1 || info.Injectors[0].FuncName != "Init" { - t.Fatalf("Load returned unexpected injectors: %+v", info.Injectors) - } -} - -func firstFuncDecl(t *testing.T, file *ast.File) *ast.FuncDecl { - t.Helper() - for _, decl := range file.Decls { - if fn, ok := decl.(*ast.FuncDecl); ok { - return fn - } - } - t.Fatal("expected function declaration in file") - return nil -} diff --git a/internal/wire/summary_provider_resolver.go b/internal/wire/summary_provider_resolver.go deleted file mode 100644 index c93e0c5..0000000 --- a/internal/wire/summary_provider_resolver.go +++ /dev/null @@ -1,223 +0,0 @@ -// Copyright 2026 The Wire Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package wire - -import ( - "context" - "fmt" - "go/token" - "go/types" - "time" - - "golang.org/x/tools/go/types/typeutil" -) - -type summaryProviderResolver struct { - ctx context.Context - fset *token.FileSet - summaries map[string]*packageSummary - importPackage func(string) (*types.Package, error) - cache map[providerSetRefSummary]*ProviderSet - resolving map[providerSetRefSummary]struct{} - supported map[string]bool -} - -func newSummaryProviderResolver(ctx context.Context, summaries map[string]*packageSummary, importPackage func(string) (*types.Package, error)) *summaryProviderResolver { - if len(summaries) == 0 || importPackage == nil { - return nil - } - r := &summaryProviderResolver{ - ctx: ctx, - fset: token.NewFileSet(), - summaries: make(map[string]*packageSummary, len(summaries)), - importPackage: importPackage, - cache: make(map[providerSetRefSummary]*ProviderSet), - resolving: make(map[providerSetRefSummary]struct{}), - supported: make(map[string]bool, len(summaries)), - } - for pkgPath, summary := range summaries { - if summary == nil { - continue - } - r.summaries[pkgPath] = summary - } - for pkgPath := range r.summaries { - r.supported[pkgPath] = r.packageSupported(pkgPath, make(map[string]struct{})) - } - return r -} - -func filterSupportedPackageSummaries(summaries map[string]*packageSummary) map[string]*packageSummary { - if len(summaries) == 0 { - return nil - } - resolver := &summaryProviderResolver{ - summaries: summaries, - supported: make(map[string]bool, len(summaries)), - } - out := make(map[string]*packageSummary) - for pkgPath, summary := range summaries { - if summary == nil { - continue - } - if resolver.packageSupported(pkgPath, make(map[string]struct{})) { - out[pkgPath] = summary - } - } - return out -} - -func (r *summaryProviderResolver) Resolve(pkgPath string, varName string) (*ProviderSet, bool, []error) { - if r == nil || !r.supported[pkgPath] { - return nil, false, nil - } - start := time.Now() - set, err := 
r.resolve(providerSetRefSummary{PkgPath: pkgPath, VarName: varName}) - logTiming(r.ctx, "incremental.local_fastpath.summary_resolve", start) - if err != nil { - return nil, true, []error{err} - } - return set, true, nil -} - -func (r *summaryProviderResolver) resolve(ref providerSetRefSummary) (*ProviderSet, error) { - if set := r.cache[ref]; set != nil { - return set, nil - } - if _, ok := r.resolving[ref]; ok { - return nil, fmt.Errorf("summary provider set cycle for %s.%s", ref.PkgPath, ref.VarName) - } - summary := r.summaries[ref.PkgPath] - if summary == nil { - return nil, fmt.Errorf("missing package summary for %s", ref.PkgPath) - } - setSummary, ok := r.findProviderSet(summary, ref.VarName) - if !ok { - return nil, fmt.Errorf("missing provider set summary for %s.%s", ref.PkgPath, ref.VarName) - } - r.resolving[ref] = struct{}{} - defer delete(r.resolving, ref) - - pkg, err := r.importPackage(ref.PkgPath) - if err != nil { - return nil, err - } - set := &ProviderSet{ - PkgPath: ref.PkgPath, - VarName: ref.VarName, - } - for _, provider := range setSummary.Providers { - resolved, err := r.resolveProvider(pkg, provider) - if err != nil { - return nil, err - } - set.Providers = append(set.Providers, resolved) - } - for _, imported := range setSummary.Imports { - child, err := r.resolve(imported) - if err != nil { - return nil, err - } - set.Imports = append(set.Imports, child) - } - hasher := typeutil.MakeHasher() - providerMap, srcMap, errs := buildProviderMap(r.fset, hasher, set) - if len(errs) > 0 { - return nil, errs[0] - } - if errs := verifyAcyclic(providerMap, hasher); len(errs) > 0 { - return nil, errs[0] - } - set.providerMap = providerMap - set.srcMap = srcMap - r.cache[ref] = set - return set, nil -} - -func (r *summaryProviderResolver) resolveProvider(pkg *types.Package, summary providerSummary) (*Provider, error) { - if summary.IsStruct || len(summary.Out) == 0 { - return nil, fmt.Errorf("unsupported summary provider %s.%s", summary.PkgPath, 
summary.Name) - } - if pkg == nil || pkg.Path() != summary.PkgPath { - var err error - pkg, err = r.importPackage(summary.PkgPath) - if err != nil { - return nil, err - } - } - obj := pkg.Scope().Lookup(summary.Name) - fn, ok := obj.(*types.Func) - if !ok { - return nil, fmt.Errorf("summary provider %s.%s missing function", summary.PkgPath, summary.Name) - } - provider, errs := processFuncProvider(r.fset, fn) - if len(errs) > 0 { - return nil, errs[0] - } - return provider, nil -} - -func (r *summaryProviderResolver) findProviderSet(summary *packageSummary, varName string) (providerSetSummary, bool) { - if summary == nil { - return providerSetSummary{}, false - } - for _, set := range summary.ProviderSets { - if set.VarName == varName { - return set, true - } - } - return providerSetSummary{}, false -} - -func (r *summaryProviderResolver) packageSupported(pkgPath string, visiting map[string]struct{}) bool { - if ok, seen := r.supported[pkgPath]; seen { - return ok - } - if _, seen := visiting[pkgPath]; seen { - return false - } - summary := r.summaries[pkgPath] - if summary == nil { - return false - } - visiting[pkgPath] = struct{}{} - defer delete(visiting, pkgPath) - for _, set := range summary.ProviderSets { - if !providerSetSummarySupported(set) { - return false - } - for _, imported := range set.Imports { - if _, ok := r.summaries[imported.PkgPath]; !ok { - return false - } - if !r.packageSupported(imported.PkgPath, visiting) { - return false - } - } - } - return true -} - -func providerSetSummarySupported(summary providerSetSummary) bool { - if len(summary.Bindings) > 0 || len(summary.Values) > 0 || len(summary.Fields) > 0 || len(summary.InputTypes) > 0 { - return false - } - for _, provider := range summary.Providers { - if provider.IsStruct { - return false - } - } - return true -} diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 24ca575..09bf814 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -21,6 +21,7 @@ import ( 
"context" "fmt" "go/ast" + "go/format" "go/printer" "go/token" "go/types" @@ -101,75 +102,69 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o if opts == nil { opts = &GenerateOptions{} } - var preloadState *incrementalPreloadState - bypassIncrementalManifest := false - coldBootstrap := false - if IncrementalEnabled(ctx, env) { - debugf(ctx, "incremental=enabled") - preloadState, _ = prepareIncrementalPreloadState(ctx, wd, env, patterns, opts) - coldBootstrap = preloadState == nil - if coldBootstrap { - ctx = withIncrementalColdBootstrap(ctx, true) - } - if cached, ok := readPreloadIncrementalManifestResultsFromState(ctx, wd, env, patterns, opts, preloadState, preloadState != nil); ok { - return cached, nil - } - if generated, ok, bypass, errs := tryIncrementalLocalFastPath(ctx, wd, env, patterns, opts, preloadState); ok || len(errs) > 0 { - return generated, errs - } else if bypass { - bypassIncrementalManifest = true - } - } - if cached, ok := readManifestResults(wd, env, patterns, opts); ok { - return cached, nil - } loadStart := time.Now() - pkgs, loader, errs := load(ctx, wd, env, opts.Tags, patterns) + pkgs, errs := load(ctx, wd, env, opts.Tags, patterns) logTiming(ctx, "generate.load", loadStart) if len(errs) > 0 { return nil, errs } - if err := validateIncrementalTouchedPackages(ctx, wd, opts, preloadState, loader.fingerprints); err != nil { - if shouldBypassIncrementalManifestAfterFastPathError(err) { - return nil, []error{err} - } - bypassIncrementalManifest = true - } - if !bypassIncrementalManifest { - if cached, ok := readIncrementalManifestResults(ctx, wd, env, patterns, opts, pkgs, loader.fingerprints); ok { - warmPackageOutputCache(pkgs, opts, cached) - return cached, nil - } - } else { - debugf(ctx, "incremental.manifest bypass reason=fastpath_error") - ctx = withBypassPackageCache(ctx) - } generated := make([]GenerateResult, len(pkgs)) for i, pkg := range pkgs { - generated[i] = generateForPackage(ctx, pkg, 
loader, opts) - } - if allGeneratedOK(generated) { - if IncrementalEnabled(ctx, env) { - if coldBootstrap { - snapshot := buildIncrementalManifestSnapshotFromPackages(wd, opts.Tags, incrementalManifestPackages(pkgs, loader)) - writeIncrementalManifestWithOptions(wd, env, patterns, opts, incrementalManifestPackages(pkgs, loader), snapshot, generated, false) - if snapshot != nil { - writeLocalPackageExports(wd, opts.Tags, incrementalManifestPackages(pkgs, loader), snapshot.fingerprints) - writeIncrementalGraphFromSnapshot(wd, opts.Tags, manifestOutputPkgPathsFromGenerated(generated), snapshot.fingerprints) - loader.fingerprints = snapshot - } - writeIncrementalPackageSummaries(loader, pkgs) - } else { - writeLocalPackageExports(wd, opts.Tags, incrementalManifestPackages(pkgs, loader), loader.fingerprints.fingerprints) - writeIncrementalPackageSummaries(loader, pkgs) - writeIncrementalManifest(wd, env, patterns, opts, incrementalManifestPackages(pkgs, loader), loader.fingerprints, generated) - } + pkgStart := time.Now() + generated[i].PkgPath = pkg.PkgPath + dirStart := time.Now() + outDir, err := detectOutputDir(pkg.GoFiles) + logTiming(ctx, "generate.package."+pkg.PkgPath+".output_dir", dirStart) + if err != nil { + generated[i].Errs = append(generated[i].Errs, err) + continue + } + generated[i].OutputPath = filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") + g := newGen(pkg) + oc := newObjectCache([]*packages.Package{pkg}) + injectorStart := time.Now() + injectorFiles, genErrs := generateInjectors(oc, g, pkg) + logTiming(ctx, "generate.package."+pkg.PkgPath+".injectors", injectorStart) + if len(genErrs) > 0 { + generated[i].Errs = genErrs + continue } - writeManifest(wd, env, patterns, opts, pkgs) + copyStart := time.Now() + copyNonInjectorDecls(g, injectorFiles, pkg.TypesInfo) + logTiming(ctx, "generate.package."+pkg.PkgPath+".copy_non_injectors", copyStart) + frameStart := time.Now() + goSrc := g.frame(opts.Tags) + logTiming(ctx, 
"generate.package."+pkg.PkgPath+".frame", frameStart) + if len(opts.Header) > 0 { + goSrc = append(opts.Header, goSrc...) + } + formatStart := time.Now() + fmtSrc, err := format.Source(goSrc) + logTiming(ctx, "generate.package."+pkg.PkgPath+".format", formatStart) + if err != nil { + generated[i].Errs = append(generated[i].Errs, err) + } else { + goSrc = fmtSrc + } + generated[i].Content = goSrc + logTiming(ctx, "generate.package."+pkg.PkgPath+".total", pkgStart) } return generated, nil } +func detectOutputDir(paths []string) (string, error) { + if len(paths) == 0 { + return "", fmt.Errorf("no files to derive output directory from") + } + dir := filepath.Dir(paths[0]) + for _, p := range paths[1:] { + if dir2 := filepath.Dir(p); dir2 != dir { + return "", fmt.Errorf("found conflicting directories %q and %q", dir, dir2) + } + } + return dir, nil +} + func manifestOutputPkgPathsFromGenerated(generated []GenerateResult) []string { if len(generated) == 0 { return nil @@ -190,46 +185,6 @@ func manifestOutputPkgPathsFromGenerated(generated []GenerateResult) []string { return out } -func warmPackageOutputCache(pkgs []*packages.Package, opts *GenerateOptions, generated []GenerateResult) { - if len(pkgs) == 0 || len(generated) == 0 { - return - } - byPkg := make(map[string][]byte, len(generated)) - for _, gen := range generated { - if len(gen.Content) == 0 { - continue - } - byPkg[gen.PkgPath] = gen.Content - } - for _, pkg := range pkgs { - content := byPkg[pkg.PkgPath] - if len(content) == 0 { - continue - } - key, err := cacheKeyForPackage(pkg, opts) - if err != nil || key == "" { - continue - } - writeCache(key, content) - } -} - -func incrementalManifestPackages(pkgs []*packages.Package, loader *lazyLoader) []*packages.Package { - if loader == nil || len(loader.loaded) == 0 { - return pkgs - } - out := make([]*packages.Package, 0, len(loader.loaded)) - for _, pkg := range loader.loaded { - if pkg != nil { - out = append(out, pkg) - } - } - if len(out) == 0 { - return 
pkgs - } - return out -} - // generateInjectors generates the injectors for a given package. func generateInjectors(oc *objectCache, g *gen, pkg *packages.Package) (injectorFiles []*ast.File, _ []error) { injectorFiles = make([]*ast.File, 0, len(pkg.Syntax)) diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go index cb167aa..dc5cfda 100644 --- a/internal/wire/wire_test.go +++ b/internal/wire/wire_test.go @@ -26,6 +26,7 @@ import ( "io/ioutil" "os" "os/exec" + "path" "path/filepath" "strings" "testing" @@ -481,6 +482,7 @@ func isIdent(s string) bool { // "C:\GOPATH" and running on Windows, the string // "C:\GOPATH\src\foo\bar.go:15:4" would be rewritten to "foo/bar.go:x:y". func scrubError(gopath string, s string) string { + s = normalizeHeaderRelativeError(s) sb := new(strings.Builder) query := gopath + string(os.PathSeparator) + "src" + string(os.PathSeparator) for { @@ -517,7 +519,106 @@ func scrubError(gopath string, s string) string { sb.WriteString(linecol) s = s[linecolLen:] } - return sb.String() + return strings.TrimRight(sb.String(), "\n") +} + +func normalizeHeaderRelativeError(s string) string { + const headerPrefix = "-: # " + if !strings.HasPrefix(s, headerPrefix) { + return s + } + pkgAndRest := strings.TrimPrefix(s, headerPrefix) + newline := strings.IndexByte(pkgAndRest, '\n') + if newline == -1 { + return s + } + pkg := strings.TrimSpace(pkgAndRest[:newline]) + rest := strings.TrimLeft(pkgAndRest[newline+1:], "\n") + if pkg == "" || rest == "" { + return s + } + + firstLineEnd := strings.IndexByte(rest, '\n') + if firstLineEnd == -1 { + firstLineEnd = len(rest) + } + firstLine := rest[:firstLineEnd] + rewritten, ok := canonicalizeRelativeErrorPath(pkg, firstLine) + if !ok { + return s + } + return normalizeLegacyUndefinedQualifiedName(rewritten + rest[firstLineEnd:]) +} + +func canonicalizeRelativeErrorPath(pkg, line string) (string, bool) { + goExt := strings.Index(line, ".go") + if goExt == -1 { + return "", false + } + goExt += 
len(".go") + linecol, n := scrubLineColumn(line[goExt:]) + if n == 0 { + return "", false + } + file := line[:goExt] + suffix := line[goExt+n:] + file = strings.ReplaceAll(file, "\\", "/") + file = strings.TrimPrefix(file, "./") + file = strings.TrimPrefix(file, "/") + baseDir := path.Base(pkg) + if strings.HasPrefix(file, pkg+"/") { + return file + linecol + suffix, true + } + if strings.HasPrefix(file, baseDir+"/") { + file = pkg + "/" + strings.TrimPrefix(file, baseDir+"/") + return file + linecol + suffix, true + } + if !strings.Contains(file, "/") { + return pkg + "/" + file + linecol + suffix, true + } + return "", false +} + +func normalizeLegacyUndefinedQualifiedName(s string) string { + const marker = ": undefined: " + idx := strings.Index(s, marker) + if idx == -1 { + return s + } + qualified := s[idx+len(marker):] + end := len(qualified) + for i, r := range qualified { + if r == '\n' || r == '\r' || r == '\t' || r == ' ' { + end = i + break + } + } + qualified = qualified[:end] + dot := strings.IndexByte(qualified, '.') + if dot == -1 || dot == 0 || dot == len(qualified)-1 { + return s + } + pkgName := qualified[:dot] + name := qualified[dot+1:] + if name == "" || !isLowerIdent(name) { + return s + } + return s[:idx] + ": name " + name + " not exported by package " + pkgName +} + +func isLowerIdent(s string) bool { + if s == "" { + return false + } + for i, r := range s { + if i == 0 && !unicode.IsLower(r) { + return false + } + if !unicode.IsLetter(r) && !unicode.IsDigit(r) && r != '_' { + return false + } + } + return true } func scrubLineColumn(s string) (replacement string, n int) { @@ -571,6 +672,24 @@ func filterLegacyCompilerErrors(errs []string) []string { return filtered } +func TestScrubErrorCanonicalizesHeaderRelativePath(t *testing.T) { + const gopath = "/tmp/wire_test" + got := scrubError(gopath, "-: # example.com/foo\nfoo/wire.go:26:33: not enough arguments in call to wire.InterfaceValue") + want := "example.com/foo/wire.go:x:y: not enough 
arguments in call to wire.InterfaceValue" + if got != want { + t.Fatalf("scrubError() = %q, want %q", got, want) + } +} + +func TestScrubErrorCanonicalizesHeaderRootRelativePath(t *testing.T) { + const gopath = "/tmp/wire_test" + got := scrubError(gopath, "-: # example.com/foo\n/wire.go:27:17: name foo not exported by package bar") + want := "example.com/foo/wire.go:x:y: name foo not exported by package bar" + if got != want { + t.Fatalf("scrubError() = %q, want %q", got, want) + } +} + type testCase struct { name string pkg string From e7cfc6390e19fee8485fbac62eb4d7eb352134bc Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 05:38:30 -0500 Subject: [PATCH 09/82] feat: external loader caching --- cmd/wire/main.go | 4 + internal/loader/artifact_cache.go | 139 ++++++ internal/loader/custom.go | 268 +++++++--- internal/loader/discovery.go | 1 + internal/loader/loader_test.go | 791 ++++++++++++++++++++++++++++++ internal/loader/timing.go | 15 + internal/wire/wire_test.go | 115 +++++ 7 files changed, 1272 insertions(+), 61 deletions(-) create mode 100644 internal/loader/artifact_cache.go diff --git a/cmd/wire/main.go b/cmd/wire/main.go index 4426ee1..f7fd92f 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -186,6 +186,10 @@ func withTiming(ctx context.Context, enabled bool) context.Context { return ctx } return wire.WithTiming(ctx, func(label string, dur time.Duration) { + if dur == 0 && strings.Contains(label, "=") { + log.Printf("timing: %s", label) + return + } log.Printf("timing: %s=%s", label, dur) }) } diff --git a/internal/loader/artifact_cache.go b/internal/loader/artifact_cache.go new file mode 100644 index 0000000..d495143 --- /dev/null +++ b/internal/loader/artifact_cache.go @@ -0,0 +1,139 @@ +// Copyright 2026 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package loader + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "go/token" + "go/types" + "os" + "path/filepath" + "runtime" + "strconv" + + "golang.org/x/tools/go/gcexportdata" +) + +const ( + loaderArtifactEnv = "WIRE_LOADER_ARTIFACTS" + loaderArtifactDirEnv = "WIRE_LOADER_ARTIFACT_DIR" +) + +func loaderArtifactEnabled(env []string) bool { + return envValue(env, loaderArtifactEnv) == "1" +} + +func loaderArtifactDir(env []string) (string, error) { + if dir := envValue(env, loaderArtifactDirEnv); dir != "" { + return dir, nil + } + base, err := os.UserCacheDir() + if err != nil { + return "", err + } + return filepath.Join(base, "wire", "loader-artifacts"), nil +} + +func loaderArtifactPath(env []string, meta *packageMeta, isLocal bool) (string, error) { + dir, err := loaderArtifactDir(env) + if err != nil { + return "", err + } + key, err := loaderArtifactKey(meta, isLocal) + if err != nil { + return "", err + } + return filepath.Join(dir, key+".bin"), nil +} + +func loaderArtifactKey(meta *packageMeta, isLocal bool) (string, error) { + sum := sha256.New() + sum.Write([]byte("wire-loader-artifact-v3\n")) + sum.Write([]byte(runtime.Version())) + sum.Write([]byte{'\n'}) + sum.Write([]byte(meta.ImportPath)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(meta.Name)) + sum.Write([]byte{'\n'}) + if !isLocal { + sum.Write([]byte(meta.Export)) + sum.Write([]byte{'\n'}) + if meta.Error != nil { + sum.Write([]byte(meta.Error.Err)) + sum.Write([]byte{'\n'}) + } + return hex.EncodeToString(sum.Sum(nil)), nil + } + for _, name := range 
metaFiles(meta) { + info, err := os.Stat(name) + if err != nil { + return "", err + } + sum.Write([]byte(name)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.Size(), 10))) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.ModTime().UnixNano(), 10))) + sum.Write([]byte{'\n'}) + } + return hex.EncodeToString(sum.Sum(nil)), nil +} + +func readLoaderArtifact(path string, fset *token.FileSet, imports map[string]*types.Package, pkgPath string) (*types.Package, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + return readLoaderArtifactData(data, fset, imports, pkgPath) +} + +func readLoaderArtifactData(data []byte, fset *token.FileSet, imports map[string]*types.Package, pkgPath string) (*types.Package, error) { + return gcexportdata.Read(bytes.NewReader(data), fset, imports, pkgPath) +} + +func writeLoaderArtifact(path string, fset *token.FileSet, pkg *types.Package) error { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + var out bytes.Buffer + if err := gcexportdata.Write(&out, fset, pkg); err != nil { + return err + } + return os.WriteFile(path, out.Bytes(), 0o644) +} + +func artifactUpToDate(env []string, artifactPath string, meta *packageMeta, isLocal bool) bool { + _, err := os.Stat(artifactPath) + return err == nil +} + +func isProviderSetTypeForLoader(t types.Type) bool { + named, ok := t.(*types.Named) + if !ok { + return false + } + obj := named.Obj() + if obj == nil || obj.Pkg() == nil { + return false + } + switch obj.Pkg().Path() { + case "github.com/goforj/wire", "github.com/google/wire": + return obj.Name() == "ProviderSet" + default: + return false + } +} diff --git a/internal/loader/custom.go b/internal/loader/custom.go index ffa2d48..49fc217 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -68,35 +68,45 @@ type customValidator struct { } type customTypedGraphLoader struct { - workspace string - ctx context.Context - fset 
*token.FileSet - meta map[string]*packageMeta - targets map[string]struct{} - parseFile ParseFileFunc - packages map[string]*packages.Package - typesPkgs map[string]*types.Package - importer types.Importer - loading map[string]bool - stats typedLoadStats + workspace string + ctx context.Context + env []string + fset *token.FileSet + meta map[string]*packageMeta + targets map[string]struct{} + parseFile ParseFileFunc + packages map[string]*packages.Package + typesPkgs map[string]*types.Package + importer types.Importer + loading map[string]bool + isLocalCache map[string]bool + stats typedLoadStats } type typedLoadStats struct { - read time.Duration - parse time.Duration - typecheck time.Duration - localRead time.Duration - externalRead time.Duration - localParse time.Duration - externalParse time.Duration - localTypecheck time.Duration - externalTypecheck time.Duration - filesRead int - packages int - localPackages int - externalPackages int - localFilesRead int - externalFilesRead int + read time.Duration + parse time.Duration + typecheck time.Duration + localRead time.Duration + externalRead time.Duration + localParse time.Duration + externalParse time.Duration + localTypecheck time.Duration + externalTypecheck time.Duration + filesRead int + packages int + localPackages int + externalPackages int + localFilesRead int + externalFilesRead int + artifactRead time.Duration + artifactPath time.Duration + artifactDecode time.Duration + artifactImportLink time.Duration + artifactWrite time.Duration + artifactHits int + artifactMisses int + artifactWrites int } func validateTouchedPackagesCustom(ctx context.Context, req TouchedValidationRequest) (*TouchedValidationResult, error) { @@ -195,8 +205,8 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes } sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) return &RootLoadResult{ - Packages: roots, - Backend: ModeCustom, + Packages: roots, + Backend: 
ModeCustom, Discovery: discoverySnapshotForMeta(meta, req.NeedDeps), }, nil } @@ -235,16 +245,19 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz fset = token.NewFileSet() } l := &customTypedGraphLoader{ - workspace: detectModuleRoot(req.WD), - ctx: ctx, - fset: fset, - meta: meta, - targets: map[string]struct{}{req.Package: {}}, - parseFile: req.ParseFile, - packages: make(map[string]*packages.Package, len(meta)), - typesPkgs: make(map[string]*types.Package, len(meta)), - importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), - loading: make(map[string]bool, len(meta)), + workspace: detectModuleRoot(req.WD), + ctx: ctx, + env: append([]string(nil), req.Env...), + fset: fset, + meta: meta, + targets: map[string]struct{}{req.Package: {}}, + parseFile: req.ParseFile, + packages: make(map[string]*packages.Package, len(meta)), + typesPkgs: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + isLocalCache: make(map[string]bool, len(meta)), + stats: typedLoadStats{}, } root, err := l.loadPackage(req.Package) if err != nil { @@ -259,6 +272,14 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz logDuration(ctx, "loader.custom.lazy.parse_files.external.cumulative", l.stats.externalParse) logDuration(ctx, "loader.custom.lazy.typecheck.local.cumulative", l.stats.localTypecheck) logDuration(ctx, "loader.custom.lazy.typecheck.external.cumulative", l.stats.externalTypecheck) + logDuration(ctx, "loader.custom.lazy.artifact_read", l.stats.artifactRead) + logDuration(ctx, "loader.custom.lazy.artifact_path", l.stats.artifactPath) + logDuration(ctx, "loader.custom.lazy.artifact_decode", l.stats.artifactDecode) + logDuration(ctx, "loader.custom.lazy.artifact_import_link", l.stats.artifactImportLink) + logDuration(ctx, "loader.custom.lazy.artifact_write", l.stats.artifactWrite) + logInt(ctx, 
"loader.custom.lazy.artifact_hits", l.stats.artifactHits) + logInt(ctx, "loader.custom.lazy.artifact_misses", l.stats.artifactMisses) + logInt(ctx, "loader.custom.lazy.artifact_writes", l.stats.artifactWrites) return &LazyLoadResult{ Packages: []*packages.Package{root}, Backend: ModeCustom, @@ -294,16 +315,19 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo return nil, unsupportedError{reason: "no root packages from metadata"} } l := &customTypedGraphLoader{ - workspace: detectModuleRoot(req.WD), - ctx: ctx, - fset: fset, - meta: meta, - targets: targets, - parseFile: req.ParseFile, - packages: make(map[string]*packages.Package, len(meta)), - typesPkgs: make(map[string]*types.Package, len(meta)), - importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), - loading: make(map[string]bool, len(meta)), + workspace: detectModuleRoot(req.WD), + ctx: ctx, + env: append([]string(nil), req.Env...), + fset: fset, + meta: meta, + targets: targets, + parseFile: req.ParseFile, + packages: make(map[string]*packages.Package, len(meta)), + typesPkgs: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + isLocalCache: make(map[string]bool, len(meta)), + stats: typedLoadStats{}, } roots := make([]*packages.Package, 0, len(targets)) for _, m := range meta { @@ -326,6 +350,14 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo logDuration(ctx, "loader.custom.typed.parse_files.external.cumulative", l.stats.externalParse) logDuration(ctx, "loader.custom.typed.typecheck.local.cumulative", l.stats.localTypecheck) logDuration(ctx, "loader.custom.typed.typecheck.external.cumulative", l.stats.externalTypecheck) + logDuration(ctx, "loader.custom.typed.artifact_read", l.stats.artifactRead) + logDuration(ctx, "loader.custom.typed.artifact_path", l.stats.artifactPath) + logDuration(ctx, 
"loader.custom.typed.artifact_decode", l.stats.artifactDecode) + logDuration(ctx, "loader.custom.typed.artifact_import_link", l.stats.artifactImportLink) + logDuration(ctx, "loader.custom.typed.artifact_write", l.stats.artifactWrite) + logInt(ctx, "loader.custom.typed.artifact_hits", l.stats.artifactHits) + logInt(ctx, "loader.custom.typed.artifact_misses", l.stats.artifactMisses) + logInt(ctx, "loader.custom.typed.artifact_writes", l.stats.artifactWrites) return &PackageLoadResult{ Packages: roots, Backend: ModeCustom, @@ -435,30 +467,51 @@ func (v *customValidator) validatePackage(path string) (*packages.Package, error } func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, error) { - if pkg := l.packages[path]; pkg != nil && (pkg.Types != nil || len(pkg.Errors) > 0) { + if path == "C" { + if pkg := l.packages[path]; pkg != nil { + return pkg, nil + } + tpkg := l.typesPkgs[path] + if tpkg == nil { + tpkg = types.NewPackage("C", "C") + l.typesPkgs[path] = tpkg + } + pkg := &packages.Package{ + ID: "C", + Name: "C", + PkgPath: "C", + Fset: l.fset, + Imports: make(map[string]*packages.Package), + Types: tpkg, + } + l.packages[path] = pkg return pkg, nil } meta := l.meta[path] if meta == nil { - return nil, unsupportedError{reason: "missing lazy-load metadata"} + return nil, unsupportedError{reason: "missing lazy-load metadata for " + path} } + pkg := l.packages[path] if l.loading[path] { - if pkg := l.packages[path]; pkg != nil { + if pkg != nil { return pkg, nil } return nil, unsupportedError{reason: "lazy-load cycle"} } + if pkg != nil && (pkg.Types != nil || len(pkg.Errors) > 0) { + return pkg, nil + } l.loading[path] = true defer delete(l.loading, path) l.stats.packages++ - isLocal := isWorkspacePackage(l.workspace, meta.Dir) + _, isTarget := l.targets[path] + isLocal := l.isLocalPackage(path, meta) if isLocal { l.stats.localPackages++ } else { l.stats.externalPackages++ } - pkg := l.packages[path] if pkg == nil { pkg = &packages.Package{ 
ID: meta.ImportPath, @@ -472,6 +525,28 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er } l.packages[path] = pkg } + useArtifact := loaderArtifactEnabled(l.env) && !isTarget && !isLocal + if useArtifact { + if typed, ok := l.readArtifact(path, meta, isLocal); ok { + linkStart := time.Now() + for _, imp := range meta.Imports { + target := imp + if mapped := meta.ImportMap[imp]; mapped != "" { + target = mapped + } + dep, err := l.loadPackage(target) + if err != nil { + return nil, err + } + pkg.Imports[imp] = dep + } + l.stats.artifactImportLink += time.Since(linkStart) + pkg.Types = typed + pkg.TypesInfo = nil + pkg.Syntax = nil + return pkg, nil + } + } files, parseErrs := l.parseFiles(metaFiles(meta), isLocal) pkg.Errors = append(pkg.Errors, parseErrs...) if len(files) == 0 { @@ -486,12 +561,11 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er } tpkg := l.typesPkgs[path] - if tpkg == nil { + if tpkg == nil || tpkg.Complete() || (tpkg.Scope() != nil && len(tpkg.Scope().Names()) > 0) { tpkg = types.NewPackage(meta.ImportPath, meta.Name) l.typesPkgs[path] = tpkg } - _, isTarget := l.targets[path] - needFullState := isTarget || isWorkspacePackage(l.workspace, meta.Dir) + needFullState := isTarget || isLocal var info *types.Info if needFullState { info = &types.Info{ @@ -506,7 +580,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er var typeErrors []packages.Error cfg := &types.Config{ Sizes: types.SizesFor("gc", runtime.GOARCH), - IgnoreFuncBodies: !isWorkspacePackage(l.workspace, meta.Dir), + IgnoreFuncBodies: !isLocal, Importer: importerFunc(func(importPath string) (*types.Package, error) { if importPath == "unsafe" { return types.Unsafe, nil @@ -537,7 +611,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er } checker := types.NewChecker(cfg, l.fset, tpkg, info) typecheckStart := time.Now() - if err := checker.Files(files); err != 
nil && len(typeErrors) == 0 { + if err := l.checkFiles(path, checker, files); err != nil && len(typeErrors) == 0 { typeErrors = append(typeErrors, toPackagesError(l.fset, err)) } typecheckDuration := time.Since(typecheckStart) @@ -555,9 +629,84 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er pkg.Types = tpkg pkg.TypesInfo = info pkg.Errors = append(pkg.Errors, typeErrors...) + if shouldWriteArtifact(l.env, isTarget, isLocal) && len(pkg.Errors) == 0 { + _ = l.writeArtifact(meta, tpkg, isLocal) + } return pkg, nil } +func (l *customTypedGraphLoader) checkFiles(path string, checker *types.Checker, files []*ast.File) (err error) { + defer func() { + if r := recover(); r != nil { + err = unsupportedError{reason: fmt.Sprintf("typecheck panic in %s: %v", path, r)} + } + }() + return checker.Files(files) +} + +func (l *customTypedGraphLoader) readArtifact(path string, meta *packageMeta, isLocal bool) (*types.Package, bool) { + start := time.Now() + pathStart := time.Now() + artifactPath, err := loaderArtifactPath(l.env, meta, isLocal) + l.stats.artifactPath += time.Since(pathStart) + if err != nil { + l.stats.artifactRead += time.Since(start) + l.stats.artifactMisses++ + return nil, false + } + var tpkg *types.Package + decodeStart := time.Now() + tpkg, err = readLoaderArtifact(artifactPath, l.fset, l.typesPkgs, path) + l.stats.artifactDecode += time.Since(decodeStart) + if err != nil { + l.stats.artifactRead += time.Since(start) + l.stats.artifactMisses++ + return nil, false + } + l.stats.artifactRead += time.Since(start) + l.stats.artifactHits++ + l.typesPkgs[path] = tpkg + return tpkg, true +} + +func (l *customTypedGraphLoader) writeArtifact(meta *packageMeta, pkg *types.Package, isLocal bool) error { + start := time.Now() + artifactPath, err := loaderArtifactPath(l.env, meta, isLocal) + if err != nil { + l.stats.artifactWrite += time.Since(start) + return err + } + if artifactUpToDate(l.env, artifactPath, meta, isLocal) { + 
l.stats.artifactWrite += time.Since(start) + return nil + } + writeErr := writeLoaderArtifact(artifactPath, l.fset, pkg) + l.stats.artifactWrite += time.Since(start) + if writeErr == nil { + l.stats.artifactWrites++ + } + if writeErr != nil { + return writeErr + } + return nil +} + +func shouldWriteArtifact(env []string, isTarget, isLocal bool) bool { + if !loaderArtifactEnabled(env) || isTarget || isLocal { + return false + } + return true +} + +func (l *customTypedGraphLoader) isLocalPackage(importPath string, meta *packageMeta) bool { + if local, ok := l.isLocalCache[importPath]; ok { + return local + } + local := isWorkspacePackage(l.workspace, meta.Dir) + l.isLocalCache[importPath] = local + return local +} + func (v *customValidator) importFromExport(path string) (*types.Package, error) { if typed := v.packages[path]; typed != nil && typed.Complete() { return typed, nil @@ -907,8 +1056,6 @@ func isWorkspacePackage(workspaceRoot, dir string) bool { if workspaceRoot == "" || dir == "" { return false } - workspaceRoot = canonicalLoaderPath(workspaceRoot) - dir = canonicalLoaderPath(dir) if dir == workspaceRoot { return true } @@ -970,7 +1117,6 @@ func envValue(env []string, key string) string { return "" } - func touchedPackageStub(fset *token.FileSet, meta *packageMeta) *packages.Package { if meta == nil { return nil diff --git a/internal/loader/discovery.go b/internal/loader/discovery.go index a6aba46..9adbf4e 100644 --- a/internal/loader/discovery.go +++ b/internal/loader/discovery.go @@ -73,6 +73,7 @@ func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, if meta.ImportPath == "" { continue } + meta.Dir = canonicalLoaderPath(meta.Dir) for i, name := range meta.GoFiles { if !filepath.IsAbs(name) { meta.GoFiles[i] = filepath.Join(meta.Dir, name) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 0e38d99..50d8690 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -15,16 
+15,22 @@ package loader import ( + "bytes" "context" + "fmt" "go/ast" "go/parser" "go/token" + "go/types" "os" "path/filepath" "sort" + "strconv" "strings" "testing" + "time" + "golang.org/x/tools/go/gcexportdata" "golang.org/x/tools/go/packages" ) @@ -351,6 +357,107 @@ func TestMetaFilesFallsBackToGoFiles(t *testing.T) { } } +func TestExportDataPairings(t *testing.T) { + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "lib.go", "package lib\n\ntype T int\n", 0) + if err != nil { + t.Fatalf("ParseFile() error = %v", err) + } + pkg, err := new(types.Config).Check("lib", fset, []*ast.File{file}, nil) + if err != nil { + t.Fatalf("types.Check() error = %v", err) + } + + t.Run("gcexportdata write/read direct", func(t *testing.T) { + var out bytes.Buffer + if err := gcexportdata.Write(&out, fset, pkg); err != nil { + t.Fatalf("gcexportdata.Write() error = %v", err) + } + got, err := gcexportdata.Read(bytes.NewReader(out.Bytes()), token.NewFileSet(), make(map[string]*types.Package), pkg.Path()) + if err != nil { + t.Fatalf("gcexportdata.Read() error = %v", err) + } + if got.Scope().Lookup("T") == nil { + t.Fatal("reimported package missing T") + } + }) + + t.Run("gcexportdata write with newreader fails", func(t *testing.T) { + var out bytes.Buffer + if err := gcexportdata.Write(&out, fset, pkg); err != nil { + t.Fatalf("gcexportdata.Write() error = %v", err) + } + if _, err := gcexportdata.NewReader(bytes.NewReader(out.Bytes())); err == nil { + t.Fatal("gcexportdata.NewReader() unexpectedly succeeded on direct gcexportdata.Write output") + } + }) +} + +func TestExportDataRoundTripWithImports(t *testing.T) { + fset := token.NewFileSet() + depPkg, err := new(types.Config).Check("example.com/dep", fset, []*ast.File{ + mustParseFile(t, fset, "dep.go", `package dep + +type T int +`), + }, nil) + if err != nil { + t.Fatalf("types.Check(dep) error = %v", err) + } + pkg, err := (&types.Config{ + Importer: importerFuncForTest(func(path string) (*types.Package, 
error) { + if path == "example.com/dep" { + return depPkg, nil + } + if path == "unsafe" { + return types.Unsafe, nil + } + return nil, nil + }), + }).Check("example.com/lib", fset, []*ast.File{ + mustParseFile(t, fset, "lib.go", `package lib + +import "example.com/dep" + +type T struct { + S dep.T +} +`), + }, nil) + if err != nil { + t.Fatalf("types.Check() error = %v", err) + } + + var out bytes.Buffer + if err := gcexportdata.Write(&out, fset, pkg); err != nil { + t.Fatalf("gcexportdata.Write() error = %v", err) + } + imports := make(map[string]*types.Package) + got, err := gcexportdata.Read(bytes.NewReader(out.Bytes()), token.NewFileSet(), imports, pkg.Path()) + if err != nil { + t.Fatalf("gcexportdata.Read() error = %v", err) + } + obj := got.Scope().Lookup("T") + if obj == nil { + t.Fatal("reimported package missing T") + } + named, ok := obj.Type().(*types.Named) + if !ok { + t.Fatalf("T type = %T, want *types.Named", obj.Type()) + } + field := named.Underlying().(*types.Struct).Field(0) + if field.Type().String() != "example.com/dep.T" { + t.Fatalf("field type = %q, want %q", field.Type().String(), "example.com/dep.T") + } + depImport := imports["example.com/dep"] + if depImport == nil { + t.Fatal("imports map missing dep") + } + if depImport.Scope().Lookup("T") == nil { + t.Fatal("dep import missing T after import") + } +} + func TestLoadTypedPackageGraphFallback(t *testing.T) { root := t.TempDir() writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") @@ -494,6 +601,644 @@ func TestLoadTypedPackageGraphCustomKeepsExternalPackagesLight(t *testing.T) { } } +func TestLoadTypedPackageGraphCustomExternalArtifactCache(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"fmt\"\n\nfunc Init() string { return fmt.Sprint(\"ok\") }\n") + + env := 
append(os.Environ(), + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + run := func() int { + var parseCalls int + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + rootPkg := got.Packages[0] + if rootPkg.Imports["fmt"] == nil { + t.Fatal("expected fmt import package") + } + return parseCalls + } + + first := run() + entries, err := os.ReadDir(artifactDir) + if err != nil { + t.Fatalf("ReadDir(%q) error = %v", artifactDir, err) + } + if len(entries) == 0 { + t.Fatal("expected artifact cache files after first run") + } + second := run() + if second >= first { + t.Fatalf("second parseCalls = %d, want less than first run %d", second, first) + } +} + +func TestLoadTypedPackageGraphCustomExternalArtifactCacheReportsHits(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"fmt\"\n\nfunc Init() string { return fmt.Sprint(\"ok\") }\n") + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + run := func() []string { + var labels []string + ctx := WithTiming(context.Background(), func(label string, _ time.Duration) { + labels = append(labels, label) + }) + l := New() + _, 
err := l.LoadTypedPackageGraph(ctx, LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + return labels + } + + _ = run() + second := run() + if !hasPrefixLabel(second, "loader.custom.lazy.artifact_hits=") { + t.Fatalf("second run labels missing artifact hit count: %v", second) + } + if !containsPositiveIntLabel(second, "loader.custom.lazy.artifact_hits=") { + t.Fatalf("second run artifact hit count was not positive: %v", second) + } +} + +func TestLoadTypedPackageGraphCustomLeafLocalArtifactCache(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\nfunc Provide() string { return \"ok\" }\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() string { return dep.Provide() }\n") + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderLocalArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + run := func() int { + var parseCalls int + l := New() + _, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | 
packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + return parseCalls + } + + first := run() + second := run() + if second >= first { + t.Fatalf("second parseCalls = %d, want less than first run %d", second, first) + } +} + +func TestLoadTypedPackageGraphCustomNonLeafLocalArtifactCacheWithoutProviderSets(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "leaf", "leaf.go"), "package leaf\n\nfunc Provide() string { return \"ok\" }\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\nimport \"example.com/app/leaf\"\n\nfunc Provide() string { return leaf.Provide() }\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() string { return dep.Provide() }\n") + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderLocalArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + run := func() int { + var parseCalls int + l := New() + _, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, 
parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + return parseCalls + } + + first := run() + second := run() + if second >= first { + t.Fatalf("second parseCalls = %d, want less than first run %d", second, first) + } +} + +func TestLoadTypedPackageGraphCustomNonLeafLocalArtifactDisabledForProviderSets(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + repoRoot, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd() error = %v", err) + } + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n\nrequire github.com/goforj/wire v0.0.0\n\nreplace github.com/goforj/wire => "+repoRoot+"\n") + writeTestFile(t, filepath.Join(root, "leaf", "leaf.go"), "package leaf\n\nfunc Provide() string { return \"ok\" }\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\nimport (\n\t\"example.com/app/leaf\"\n\t\"github.com/goforj/wire\"\n)\n\nfunc Provide() string { return leaf.Provide() }\n\nvar Set = wire.NewSet(Provide)\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() string { return dep.Provide() }\n") + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderLocalArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + run := func() int { + var parseCalls int + l := New() + _, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, 
parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + return parseCalls + } + + first := run() + second := run() + if second < first-1 { + t.Fatalf("second parseCalls = %d, expected provider-set package to stay near first run %d", second, first) + } + meta, err := runGoList(context.Background(), goListRequest{ + WD: root, + Env: env, + Patterns: []string{"./app"}, + NeedDeps: true, + }) + if err != nil { + t.Fatalf("runGoList() error = %v", err) + } + depMeta := meta["example.com/app/dep"] + if depMeta == nil { + t.Fatal("missing metadata for example.com/app/dep") + } + hasProviderSets, ok := readLocalArtifactProviderSetFlag(env, depMeta) + if ok && !hasProviderSets { + t.Fatal("expected provider-set package metadata to record provider sets") + } +} + +func TestLoadTypedPackageGraphCustomLocalArtifactDisabledForProviderSetImporter(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + repoRoot, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd() error = %v", err) + } + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n\nrequire github.com/goforj/wire v0.0.0\n\nreplace github.com/goforj/wire => "+repoRoot+"\n") + writeTestFile(t, filepath.Join(root, "jobs", "jobs.go"), "package jobs\n\nfunc Provide() string { return \"ok\" }\n") + writeTestFile(t, filepath.Join(root, "cmd", "cmd.go"), "package cmd\n\nimport (\n\t\"example.com/app/jobs\"\n\t\"github.com/goforj/wire\"\n)\n\nvar Set = wire.NewSet(jobs.Provide)\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport _ \"example.com/app/cmd\"\n\nfunc Init() string { return \"ok\" }\n") + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderLocalArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + run := func() (int, []string) { + var ( + parseCalls int + labels []string + ) + ctx := WithTiming(context.Background(), func(label 
string, _ time.Duration) { + labels = append(labels, label) + }) + l := New() + _, err := l.LoadTypedPackageGraph(ctx, LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + return parseCalls, labels + } + + _, _ = run() + secondCalls, _ := run() + if secondCalls < 2 { + t.Fatalf("second parseCalls = %d, expected source load for cmd and jobs", secondCalls) + } +} + +func TestLoadTypedPackageGraphCustomLocalArtifactDisabledForWireDeclImporter(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + repoRoot, err := os.Getwd() + if err != nil { + t.Fatalf("Getwd() error = %v", err) + } + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n\nrequire github.com/goforj/wire v0.0.0\n\nreplace github.com/goforj/wire => "+repoRoot+"\n") + writeTestFile(t, filepath.Join(root, "jobs", "jobs.go"), "package jobs\n\nfunc Provide() string { return \"ok\" }\n") + writeTestFile(t, filepath.Join(root, "cfg", "cfg.go"), "package cfg\n\nimport (\n\t\"example.com/app/jobs\"\n\t\"github.com/goforj/wire\"\n)\n\nvar V = wire.Value(jobs.Provide())\n") + writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport _ \"example.com/app/cfg\"\n\nfunc Init() string { return \"ok\" }\n") + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderLocalArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + run := func() int { + var parseCalls int 
+ l := New() + _, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + parseCalls++ + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) + } + return parseCalls + } + + first := run() + second := run() + if second < first-1 { + t.Fatalf("second parseCalls = %d, expected source load for cfg and jobs to remain near first run %d", second, first) + } + meta, err := runGoList(context.Background(), goListRequest{ + WD: root, + Env: env, + Patterns: []string{"./app"}, + NeedDeps: true, + }) + if err != nil { + t.Fatalf("runGoList() error = %v", err) + } + cfgMeta := meta["example.com/app/cfg"] + if cfgMeta == nil { + t.Fatal("missing metadata for example.com/app/cfg") + } + flags, ok := readLocalArtifactFlags(env, cfgMeta) + if ok && !flags.wireDecls { + t.Fatal("expected wire decl package metadata to record wire declarations") + } +} + +func TestLoadTypedPackageGraphCustomLocalArtifactPreservesImportedPackageName(t *testing.T) { + root := t.TempDir() + artifactDir := t.TempDir() + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "models", "models.go"), "package models\n\nfunc NewRepo() string { return \"ok\" }\n") + writeTestFile(t, filepath.Join(root, "root", "wire.go"), "package root\n\nimport \"example.com/app/models\"\n\nvar _ = models.NewRepo\n") + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderLocalArtifactEnv+"=1", + 
loaderArtifactDirEnv+"="+artifactDir, + ) + + l := New() + load := func() (*LazyLoadResult, error) { + return l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: env, + Package: "example.com/app/root", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + } + + first, err := load() + if err != nil { + t.Fatalf("first LoadTypedPackageGraph(custom) error = %v", err) + } + second, err := load() + if err != nil { + t.Fatalf("second LoadTypedPackageGraph(custom) error = %v", err) + } + firstRoot := collectGraph(first.Packages)["example.com/app/root"] + secondRoot := collectGraph(second.Packages)["example.com/app/root"] + if firstRoot == nil || secondRoot == nil { + t.Fatal("missing root package") + } + firstModels := firstRoot.Imports["example.com/app/models"] + secondModels := secondRoot.Imports["example.com/app/models"] + if firstModels == nil || secondModels == nil { + t.Fatal("missing imported models package") + } + if firstModels.Types == nil || secondModels.Types == nil { + t.Fatal("expected imported models package to be typed") + } + if firstModels.Types.Name() != "models" { + t.Fatalf("first imported package name = %q, want %q", firstModels.Types.Name(), "models") + } + if secondModels.Types.Name() != "models" { + t.Fatalf("second imported package name = %q, want %q", secondModels.Types.Name(), "models") + } +} + +func TestLoadTypedPackageGraphCustomRealAppDirectImporterBoundarySelectors(t *testing.T) { + root := os.Getenv("WIRE_REAL_APP_ROOT") + if root == "" { + t.Skip("WIRE_REAL_APP_ROOT not set") + } + artifactDir := 
t.TempDir() + env := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderLocalArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + "WIRE_LOADER_LOCAL_BOUNDARY=direct_importers", + ) + load := func() (*LazyLoadResult, error) { + l := New() + return l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: root, + Env: env, + Package: "test/wire", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + } + if _, err := load(); err != nil { + t.Fatalf("warm LoadTypedPackageGraph(custom) error = %v", err) + } + got, err := load() + if err != nil { + t.Fatalf("second LoadTypedPackageGraph(custom) error = %v", err) + } + graph := collectGraph(got.Packages) + rootPkg := graph["test/wire"] + if rootPkg == nil { + t.Fatal("missing root package test/wire") + } + checkSelector := func(fileSuffix, pkgIdentName string) { + t.Helper() + var targetFile *ast.File + for _, f := range rootPkg.Syntax { + name := rootPkg.Fset.File(f.Pos()).Name() + if strings.HasSuffix(name, fileSuffix) { + targetFile = f + break + } + } + if targetFile == nil { + t.Fatalf("missing syntax file %s", fileSuffix) + } + found := false + ast.Inspect(targetFile, func(node ast.Node) bool { + sel, ok := node.(*ast.SelectorExpr) + if !ok { + return true + } + pkgIdent, ok := sel.X.(*ast.Ident) + if !ok || pkgIdent.Name != pkgIdentName { + return true + } + found = true + pkgObj, ok := rootPkg.TypesInfo.ObjectOf(pkgIdent).(*types.PkgName) + if !ok || pkgObj == nil { + var importBindings []string + for _, spec := range targetFile.Imports { + obj := rootPkg.TypesInfo.Implicits[spec] + 
path, _ := strconv.Unquote(spec.Path.Value) + name := "" + if spec.Name != nil { + name = spec.Name.Name + } + switch typed := obj.(type) { + case *types.PkgName: + importBindings = append(importBindings, fmt.Sprintf("%s=>%s(%s)", name, typed.Imported().Path(), typed.Imported().Name())) + case nil: + importBindings = append(importBindings, fmt.Sprintf("%s=>nil[%s]", name, path)) + default: + importBindings = append(importBindings, fmt.Sprintf("%s=>%T[%s]", name, obj, path)) + } + } + importPath := "" + for _, spec := range targetFile.Imports { + path, _ := strconv.Unquote(spec.Path.Value) + name := filepath.Base(path) + if spec.Name != nil { + name = spec.Name.Name + } + if name == pkgIdentName { + importPath = path + break + } + } + var depSummary string + if importPath != "" { + if dep := graph[importPath]; dep != nil { + depSummary = fmt.Sprintf("dep=%s name=%q types=%v typeName=%q errors=%v", importPath, dep.Name, dep.Types != nil, func() string { + if dep.Types == nil { + return "" + } + return dep.Types.Name() + }(), dep.Errors) + } else { + depSummary = "dep_missing=" + importPath + } + } + t.Fatalf("%s selector lost package object for %s; imports=%s; importPath=%q; %s; root errors=%v", fileSuffix, pkgIdentName, strings.Join(importBindings, ", "), importPath, depSummary, rootPkg.Errors) + } + if rootPkg.TypesInfo.ObjectOf(sel.Sel) == nil { + t.Fatalf("%s selector lost object for %s.%s", fileSuffix, pkgIdentName, sel.Sel.Name) + } + return false + }) + if !found { + t.Fatalf("did not find selector using %s in %s", pkgIdentName, fileSuffix) + } + } + checkSelector("inject_repositories.go", "models") + checkSelector("inject_http.go", "http") + if len(rootPkg.Errors) > 0 { + var msgs []string + for _, err := range rootPkg.Errors { + msgs = append(msgs, err.Msg) + } + t.Fatalf("root package has errors under direct importer boundary: %s", strings.Join(msgs, "; ")) + } + for _, p := range []string{"test/internal/models", "test/internal/http"} { + dep := graph[p] + 
if dep == nil { + t.Fatalf("missing dependency package %s", p) + } + if dep.Types == nil { + t.Fatalf("dependency %s missing types", p) + } + if dep.Name == "" || dep.Types.Name() == "" { + t.Fatalf("dependency %s missing package name", p) + } + if dep.Name != dep.Types.Name() { + t.Fatalf("dependency %s package name mismatch: pkg=%q types=%q", p, dep.Name, dep.Types.Name()) + } + } + _ = fmt.Sprintf +} + +func TestLoadTypedPackageGraphCustomExternalArtifactCacheRealAppParity(t *testing.T) { + root := os.Getenv("WIRE_REAL_APP_ROOT") + if root == "" { + t.Skip("WIRE_REAL_APP_ROOT not set") + } + artifactDir := t.TempDir() + load := func(env []string) (map[string]*packages.Package, error) { + l := New() + got, err := l.LoadPackages(context.Background(), PackageLoadRequest{ + WD: root, + Env: env, + Patterns: []string{"."}, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: ModeCustom, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err != nil { + return nil, err + } + return collectGraph(got.Packages), nil + } + + base, err := load(os.Environ()) + if err != nil { + t.Fatalf("base load error = %v", err) + } + withArtifactsEnv := append(os.Environ(), + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + firstArtifact, err := load(withArtifactsEnv) + if err != nil { + t.Fatalf("first artifact load error = %v", err) + } + secondArtifact, err := load(withArtifactsEnv) + if err != nil { + t.Fatalf("second artifact load error = %v", err) + } + if len(base) != len(firstArtifact) { + t.Fatalf("first artifact graph size = %d, want %d", len(firstArtifact), len(base)) + } + if len(base) != 
len(secondArtifact) { + var missing []string + for path := range base { + if secondArtifact[path] == nil { + missing = append(missing, path) + } + } + sort.Strings(missing) + parents := make(map[string][]string) + for parentPath, pkg := range base { + for impPath := range pkg.Imports { + if secondArtifact[impPath] == nil { + parents[impPath] = append(parents[impPath], parentPath) + } + } + } + parentSummary := make([]string, 0, 5) + for _, path := range missing { + if len(parentSummary) == 5 { + break + } + importers := append([]string(nil), parents[path]...) + sort.Strings(importers) + if len(importers) > 3 { + importers = importers[:3] + } + parentSummary = append(parentSummary, path+" <- "+strings.Join(importers, ",")) + } + if len(missing) > 20 { + missing = missing[:20] + } + secondParent := secondArtifact["github.com/shirou/gopsutil/v4/internal/common"] + secondParentImports := []string(nil) + if secondParent != nil { + secondParentImports = sortedImportPaths(secondParent.Imports) + } + internalCommonParents := append([]string(nil), parents["github.com/shirou/gopsutil/v4/internal/common"]...) 
+ sort.Strings(internalCommonParents) + t.Fatalf("second artifact graph size = %d, want %d; missing sample=%v; parent sample=%v; gopsutil/internal/common parents=%v; gopsutil/internal/common imports on second run=%v", len(secondArtifact), len(base), missing, parentSummary, internalCommonParents, secondParentImports) + } + if compiledFileCount(base) != compiledFileCount(secondArtifact) { + t.Fatalf("second artifact compiled file count = %d, want %d", compiledFileCount(secondArtifact), compiledFileCount(base)) + } +} + func TestLoadRootGraphCustomMatchesFallback(t *testing.T) { root := t.TempDir() writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") @@ -702,6 +1447,14 @@ func collectGraph(roots []*packages.Package) map[string]*packages.Package { return out } +func compiledFileCount(pkgs map[string]*packages.Package) int { + total := 0 + for _, pkg := range pkgs { + total += len(pkg.CompiledGoFiles) + } + return total +} + func equalStrings(a, b []string) bool { if len(a) != len(b) { return false @@ -737,6 +1490,21 @@ func sortedImportPaths(m map[string]*packages.Package) []string { return out } +type importerFuncForTest func(string) (*types.Package, error) + +func (f importerFuncForTest) Import(path string) (*types.Package, error) { + return f(path) +} + +func mustParseFile(t *testing.T, fset *token.FileSet, filename, src string) *ast.File { + t.Helper() + file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + if err != nil { + t.Fatalf("ParseFile(%q) error = %v", filename, err) + } + return file +} + func normalizePathForCompare(path string) string { if path == "" { return "" @@ -771,6 +1539,29 @@ func comparableErrors(errs []packages.Error) []string { return out } +func hasPrefixLabel(labels []string, prefix string) bool { + for _, label := range labels { + if strings.HasPrefix(label, prefix) { + return true + } + } + return false +} + +func containsPositiveIntLabel(labels []string, 
prefix string) bool { + for _, label := range labels { + if !strings.HasPrefix(label, prefix) { + continue + } + value := strings.TrimPrefix(label, prefix) + n, err := strconv.Atoi(value) + if err == nil && n > 0 { + return true + } + } + return false +} + func normalizeErrorPos(pos string) string { if pos == "" || pos == "-" { return pos diff --git a/internal/loader/timing.go b/internal/loader/timing.go index 0211f17..4b902db 100644 --- a/internal/loader/timing.go +++ b/internal/loader/timing.go @@ -2,6 +2,8 @@ package loader import ( "context" + "fmt" + "log" "time" ) @@ -39,3 +41,16 @@ func logDuration(ctx context.Context, label string, d time.Duration) { t(label, d) } } + +func logInt(ctx context.Context, label string, v int) { + if t := timing(ctx); t != nil { + t(fmt.Sprintf("%s=%d", label, v), 0) + } +} + +func debugf(ctx context.Context, format string, args ...interface{}) { + if timing(ctx) == nil { + return + } + log.Printf("timing: "+format, args...) +} diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go index dc5cfda..7bd6c20 100644 --- a/internal/wire/wire_test.go +++ b/internal/wire/wire_test.go @@ -28,6 +28,7 @@ import ( "os/exec" "path" "path/filepath" + "sort" "strings" "testing" "unicode" @@ -220,6 +221,120 @@ func TestGenerateResultCommitWithStatus(t *testing.T) { } } +func TestGenerateRealAppArtifactParity(t *testing.T) { + root := os.Getenv("WIRE_REAL_APP_ROOT") + if root == "" { + t.Skip("WIRE_REAL_APP_ROOT not set") + } + artifactDir := t.TempDir() + ctx := context.Background() + + run := func(env []string) ([]GenerateResult, []string) { + t.Helper() + gens, errs := Generate(ctx, root, env, []string{"."}, &GenerateOptions{}) + errStrings := make([]string, len(errs)) + for i, err := range errs { + errStrings[i] = err.Error() + } + sort.Strings(errStrings) + return gens, errStrings + } + + baseGens, baseErrs := run(os.Environ()) + artifactEnv := append(os.Environ(), + "WIRE_LOADER_ARTIFACTS=1", + 
"WIRE_LOADER_ARTIFACT_DIR="+artifactDir, + ) + _, warmErrs := run(artifactEnv) + if diff := cmp.Diff(baseErrs, warmErrs); diff != "" { + t.Fatalf("artifact warm-up errors mismatch (-base +warm):\n%s", diff) + } + artifactGens, artifactErrs := run(artifactEnv) + if diff := cmp.Diff(baseErrs, artifactErrs); diff != "" { + t.Fatalf("artifact errors mismatch (-base +artifact):\n%s", diff) + } + if len(baseGens) != len(artifactGens) { + t.Fatalf("generated file count = %d, want %d", len(artifactGens), len(baseGens)) + } + for i := range baseGens { + if baseGens[i].PkgPath != artifactGens[i].PkgPath { + t.Fatalf("generated package[%d] = %q, want %q", i, artifactGens[i].PkgPath, baseGens[i].PkgPath) + } + if diff := cmp.Diff(string(baseGens[i].Content), string(artifactGens[i].Content)); diff != "" { + t.Fatalf("generated content mismatch for %q (-base +artifact):\n%s", baseGens[i].PkgPath, diff) + } + baseGenErrs := comparableGenerateErrors(baseGens[i].Errs) + artifactGenErrs := comparableGenerateErrors(artifactGens[i].Errs) + if diff := cmp.Diff(baseGenErrs, artifactGenErrs); diff != "" { + t.Fatalf("generate errs mismatch for %q (-base +artifact):\n%s", baseGens[i].PkgPath, diff) + } + } +} + +func TestGenerateRealAppSelfOnlyArtifactParity(t *testing.T) { + root := os.Getenv("WIRE_REAL_APP_ROOT") + if root == "" { + t.Skip("WIRE_REAL_APP_ROOT not set") + } + artifactDir := t.TempDir() + ctx := context.Background() + + run := func(env []string) ([]GenerateResult, []string) { + t.Helper() + gens, errs := Generate(ctx, root, env, []string{"."}, &GenerateOptions{}) + errStrings := make([]string, len(errs)) + for i, err := range errs { + errStrings[i] = err.Error() + } + sort.Strings(errStrings) + return gens, errStrings + } + + artifactEnv := append(os.Environ(), + "WIRE_LOADER_ARTIFACTS=1", + "WIRE_LOADER_LOCAL_ARTIFACTS=1", + "WIRE_LOADER_ARTIFACT_DIR="+artifactDir, + ) + _, warmErrs := run(artifactEnv) + if len(warmErrs) > 0 { + t.Fatalf("artifact warm-up errors: %v", 
warmErrs) + } + baseGens, baseErrs := run(artifactEnv) + + selfOnlyEnv := append(append([]string(nil), artifactEnv...), + "WIRE_LOADER_LOCAL_BOUNDARY=self_only", + ) + selfOnlyGens, selfOnlyErrs := run(selfOnlyEnv) + if diff := cmp.Diff(baseErrs, selfOnlyErrs); diff != "" { + t.Fatalf("self_only errors mismatch (-base +self_only):\n%s", diff) + } + if len(baseGens) != len(selfOnlyGens) { + t.Fatalf("generated file count = %d, want %d", len(selfOnlyGens), len(baseGens)) + } + for i := range baseGens { + if baseGens[i].PkgPath != selfOnlyGens[i].PkgPath { + t.Fatalf("generated package[%d] = %q, want %q", i, selfOnlyGens[i].PkgPath, baseGens[i].PkgPath) + } + if diff := cmp.Diff(string(baseGens[i].Content), string(selfOnlyGens[i].Content)); diff != "" { + t.Fatalf("generated content mismatch for %q (-base +self_only):\n%s", baseGens[i].PkgPath, diff) + } + baseGenErrs := comparableGenerateErrors(baseGens[i].Errs) + selfOnlyGenErrs := comparableGenerateErrors(selfOnlyGens[i].Errs) + if diff := cmp.Diff(baseGenErrs, selfOnlyGenErrs); diff != "" { + t.Fatalf("generate errs mismatch for %q (-base +self_only):\n%s", baseGens[i].PkgPath, diff) + } + } +} + +func comparableGenerateErrors(errs []error) []string { + out := make([]string, len(errs)) + for i, err := range errs { + out[i] = err.Error() + } + sort.Strings(out) + return out +} + func TestZeroValue(t *testing.T) { t.Parallel() From 93796593f84ca7cd5cc127f9a1fd8bc989df0f35 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 05:44:08 -0500 Subject: [PATCH 10/82] chore: remove local caching strat --- internal/loader/loader_test.go | 457 --------------------------------- internal/wire/wire_test.go | 55 ---- 2 files changed, 512 deletions(-) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 50d8690..0c871e0 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -17,7 +17,6 @@ package loader import ( "bytes" "context" - "fmt" "go/ast" "go/parser" 
"go/token" @@ -694,462 +693,6 @@ func TestLoadTypedPackageGraphCustomExternalArtifactCacheReportsHits(t *testing. } } -func TestLoadTypedPackageGraphCustomLeafLocalArtifactCache(t *testing.T) { - root := t.TempDir() - artifactDir := t.TempDir() - writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") - writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\nfunc Provide() string { return \"ok\" }\n") - writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() string { return dep.Provide() }\n") - env := append(os.Environ(), - loaderArtifactEnv+"=1", - loaderLocalArtifactEnv+"=1", - loaderArtifactDirEnv+"="+artifactDir, - ) - - run := func() int { - var parseCalls int - l := New() - _, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ - WD: root, - Env: env, - Package: "example.com/app/app", - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, - LoaderMode: ModeCustom, - Fset: token.NewFileSet(), - ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - parseCalls++ - return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) - }, - }) - if err != nil { - t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) - } - return parseCalls - } - - first := run() - second := run() - if second >= first { - t.Fatalf("second parseCalls = %d, want less than first run %d", second, first) - } -} - -func TestLoadTypedPackageGraphCustomNonLeafLocalArtifactCacheWithoutProviderSets(t *testing.T) { - root := t.TempDir() - artifactDir := t.TempDir() - writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") - writeTestFile(t, filepath.Join(root, "leaf", "leaf.go"), "package leaf\n\nfunc Provide() 
string { return \"ok\" }\n") - writeTestFile(t, filepath.Join(root, "dep", "dep.go"), "package dep\n\nimport \"example.com/app/leaf\"\n\nfunc Provide() string { return leaf.Provide() }\n") - writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() string { return dep.Provide() }\n") - env := append(os.Environ(), - loaderArtifactEnv+"=1", - loaderLocalArtifactEnv+"=1", - loaderArtifactDirEnv+"="+artifactDir, - ) - - run := func() int { - var parseCalls int - l := New() - _, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ - WD: root, - Env: env, - Package: "example.com/app/app", - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, - LoaderMode: ModeCustom, - Fset: token.NewFileSet(), - ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - parseCalls++ - return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) - }, - }) - if err != nil { - t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) - } - return parseCalls - } - - first := run() - second := run() - if second >= first { - t.Fatalf("second parseCalls = %d, want less than first run %d", second, first) - } -} - -func TestLoadTypedPackageGraphCustomNonLeafLocalArtifactDisabledForProviderSets(t *testing.T) { - root := t.TempDir() - artifactDir := t.TempDir() - repoRoot, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd() error = %v", err) - } - writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n\nrequire github.com/goforj/wire v0.0.0\n\nreplace github.com/goforj/wire => "+repoRoot+"\n") - writeTestFile(t, filepath.Join(root, "leaf", "leaf.go"), "package leaf\n\nfunc Provide() string { return \"ok\" }\n") - writeTestFile(t, filepath.Join(root, "dep", "dep.go"), 
"package dep\n\nimport (\n\t\"example.com/app/leaf\"\n\t\"github.com/goforj/wire\"\n)\n\nfunc Provide() string { return leaf.Provide() }\n\nvar Set = wire.NewSet(Provide)\n") - writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport \"example.com/app/dep\"\n\nfunc Init() string { return dep.Provide() }\n") - env := append(os.Environ(), - loaderArtifactEnv+"=1", - loaderLocalArtifactEnv+"=1", - loaderArtifactDirEnv+"="+artifactDir, - ) - - run := func() int { - var parseCalls int - l := New() - _, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ - WD: root, - Env: env, - Package: "example.com/app/app", - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, - LoaderMode: ModeCustom, - Fset: token.NewFileSet(), - ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - parseCalls++ - return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) - }, - }) - if err != nil { - t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) - } - return parseCalls - } - - first := run() - second := run() - if second < first-1 { - t.Fatalf("second parseCalls = %d, expected provider-set package to stay near first run %d", second, first) - } - meta, err := runGoList(context.Background(), goListRequest{ - WD: root, - Env: env, - Patterns: []string{"./app"}, - NeedDeps: true, - }) - if err != nil { - t.Fatalf("runGoList() error = %v", err) - } - depMeta := meta["example.com/app/dep"] - if depMeta == nil { - t.Fatal("missing metadata for example.com/app/dep") - } - hasProviderSets, ok := readLocalArtifactProviderSetFlag(env, depMeta) - if ok && !hasProviderSets { - t.Fatal("expected provider-set package metadata to record provider sets") - } -} - -func 
TestLoadTypedPackageGraphCustomLocalArtifactDisabledForProviderSetImporter(t *testing.T) { - root := t.TempDir() - artifactDir := t.TempDir() - repoRoot, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd() error = %v", err) - } - writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n\nrequire github.com/goforj/wire v0.0.0\n\nreplace github.com/goforj/wire => "+repoRoot+"\n") - writeTestFile(t, filepath.Join(root, "jobs", "jobs.go"), "package jobs\n\nfunc Provide() string { return \"ok\" }\n") - writeTestFile(t, filepath.Join(root, "cmd", "cmd.go"), "package cmd\n\nimport (\n\t\"example.com/app/jobs\"\n\t\"github.com/goforj/wire\"\n)\n\nvar Set = wire.NewSet(jobs.Provide)\n") - writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport _ \"example.com/app/cmd\"\n\nfunc Init() string { return \"ok\" }\n") - env := append(os.Environ(), - loaderArtifactEnv+"=1", - loaderLocalArtifactEnv+"=1", - loaderArtifactDirEnv+"="+artifactDir, - ) - - run := func() (int, []string) { - var ( - parseCalls int - labels []string - ) - ctx := WithTiming(context.Background(), func(label string, _ time.Duration) { - labels = append(labels, label) - }) - l := New() - _, err := l.LoadTypedPackageGraph(ctx, LazyLoadRequest{ - WD: root, - Env: env, - Package: "example.com/app/app", - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, - LoaderMode: ModeCustom, - Fset: token.NewFileSet(), - ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - parseCalls++ - return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) - }, - }) - if err != nil { - t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) - } - return parseCalls, labels - } - - _, _ = run() - secondCalls, _ := run() - if secondCalls < 2 { - 
t.Fatalf("second parseCalls = %d, expected source load for cmd and jobs", secondCalls) - } -} - -func TestLoadTypedPackageGraphCustomLocalArtifactDisabledForWireDeclImporter(t *testing.T) { - root := t.TempDir() - artifactDir := t.TempDir() - repoRoot, err := os.Getwd() - if err != nil { - t.Fatalf("Getwd() error = %v", err) - } - writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n\nrequire github.com/goforj/wire v0.0.0\n\nreplace github.com/goforj/wire => "+repoRoot+"\n") - writeTestFile(t, filepath.Join(root, "jobs", "jobs.go"), "package jobs\n\nfunc Provide() string { return \"ok\" }\n") - writeTestFile(t, filepath.Join(root, "cfg", "cfg.go"), "package cfg\n\nimport (\n\t\"example.com/app/jobs\"\n\t\"github.com/goforj/wire\"\n)\n\nvar V = wire.Value(jobs.Provide())\n") - writeTestFile(t, filepath.Join(root, "app", "wire.go"), "package app\n\nimport _ \"example.com/app/cfg\"\n\nfunc Init() string { return \"ok\" }\n") - env := append(os.Environ(), - loaderArtifactEnv+"=1", - loaderLocalArtifactEnv+"=1", - loaderArtifactDirEnv+"="+artifactDir, - ) - - run := func() int { - var parseCalls int - l := New() - _, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ - WD: root, - Env: env, - Package: "example.com/app/app", - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, - LoaderMode: ModeCustom, - Fset: token.NewFileSet(), - ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - parseCalls++ - return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) - }, - }) - if err != nil { - t.Fatalf("LoadTypedPackageGraph(custom) error = %v", err) - } - return parseCalls - } - - first := run() - second := run() - if second < first-1 { - t.Fatalf("second parseCalls = %d, expected source load for 
cfg and jobs to remain near first run %d", second, first) - } - meta, err := runGoList(context.Background(), goListRequest{ - WD: root, - Env: env, - Patterns: []string{"./app"}, - NeedDeps: true, - }) - if err != nil { - t.Fatalf("runGoList() error = %v", err) - } - cfgMeta := meta["example.com/app/cfg"] - if cfgMeta == nil { - t.Fatal("missing metadata for example.com/app/cfg") - } - flags, ok := readLocalArtifactFlags(env, cfgMeta) - if ok && !flags.wireDecls { - t.Fatal("expected wire decl package metadata to record wire declarations") - } -} - -func TestLoadTypedPackageGraphCustomLocalArtifactPreservesImportedPackageName(t *testing.T) { - root := t.TempDir() - artifactDir := t.TempDir() - writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") - writeTestFile(t, filepath.Join(root, "models", "models.go"), "package models\n\nfunc NewRepo() string { return \"ok\" }\n") - writeTestFile(t, filepath.Join(root, "root", "wire.go"), "package root\n\nimport \"example.com/app/models\"\n\nvar _ = models.NewRepo\n") - env := append(os.Environ(), - loaderArtifactEnv+"=1", - loaderLocalArtifactEnv+"=1", - loaderArtifactDirEnv+"="+artifactDir, - ) - - l := New() - load := func() (*LazyLoadResult, error) { - return l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ - WD: root, - Env: env, - Package: "example.com/app/root", - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, - LoaderMode: ModeCustom, - Fset: token.NewFileSet(), - ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) - }, - }) - } - - first, err := load() - if err != nil { - t.Fatalf("first LoadTypedPackageGraph(custom) error = %v", err) - } - second, err := load() - if err != nil 
{ - t.Fatalf("second LoadTypedPackageGraph(custom) error = %v", err) - } - firstRoot := collectGraph(first.Packages)["example.com/app/root"] - secondRoot := collectGraph(second.Packages)["example.com/app/root"] - if firstRoot == nil || secondRoot == nil { - t.Fatal("missing root package") - } - firstModels := firstRoot.Imports["example.com/app/models"] - secondModels := secondRoot.Imports["example.com/app/models"] - if firstModels == nil || secondModels == nil { - t.Fatal("missing imported models package") - } - if firstModels.Types == nil || secondModels.Types == nil { - t.Fatal("expected imported models package to be typed") - } - if firstModels.Types.Name() != "models" { - t.Fatalf("first imported package name = %q, want %q", firstModels.Types.Name(), "models") - } - if secondModels.Types.Name() != "models" { - t.Fatalf("second imported package name = %q, want %q", secondModels.Types.Name(), "models") - } -} - -func TestLoadTypedPackageGraphCustomRealAppDirectImporterBoundarySelectors(t *testing.T) { - root := os.Getenv("WIRE_REAL_APP_ROOT") - if root == "" { - t.Skip("WIRE_REAL_APP_ROOT not set") - } - artifactDir := t.TempDir() - env := append(os.Environ(), - loaderArtifactEnv+"=1", - loaderLocalArtifactEnv+"=1", - loaderArtifactDirEnv+"="+artifactDir, - "WIRE_LOADER_LOCAL_BOUNDARY=direct_importers", - ) - load := func() (*LazyLoadResult, error) { - l := New() - return l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ - WD: root, - Env: env, - Package: "test/wire", - Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, - LoaderMode: ModeCustom, - Fset: token.NewFileSet(), - ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { - return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) - }, - }) - } - if _, err := 
load(); err != nil { - t.Fatalf("warm LoadTypedPackageGraph(custom) error = %v", err) - } - got, err := load() - if err != nil { - t.Fatalf("second LoadTypedPackageGraph(custom) error = %v", err) - } - graph := collectGraph(got.Packages) - rootPkg := graph["test/wire"] - if rootPkg == nil { - t.Fatal("missing root package test/wire") - } - checkSelector := func(fileSuffix, pkgIdentName string) { - t.Helper() - var targetFile *ast.File - for _, f := range rootPkg.Syntax { - name := rootPkg.Fset.File(f.Pos()).Name() - if strings.HasSuffix(name, fileSuffix) { - targetFile = f - break - } - } - if targetFile == nil { - t.Fatalf("missing syntax file %s", fileSuffix) - } - found := false - ast.Inspect(targetFile, func(node ast.Node) bool { - sel, ok := node.(*ast.SelectorExpr) - if !ok { - return true - } - pkgIdent, ok := sel.X.(*ast.Ident) - if !ok || pkgIdent.Name != pkgIdentName { - return true - } - found = true - pkgObj, ok := rootPkg.TypesInfo.ObjectOf(pkgIdent).(*types.PkgName) - if !ok || pkgObj == nil { - var importBindings []string - for _, spec := range targetFile.Imports { - obj := rootPkg.TypesInfo.Implicits[spec] - path, _ := strconv.Unquote(spec.Path.Value) - name := "" - if spec.Name != nil { - name = spec.Name.Name - } - switch typed := obj.(type) { - case *types.PkgName: - importBindings = append(importBindings, fmt.Sprintf("%s=>%s(%s)", name, typed.Imported().Path(), typed.Imported().Name())) - case nil: - importBindings = append(importBindings, fmt.Sprintf("%s=>nil[%s]", name, path)) - default: - importBindings = append(importBindings, fmt.Sprintf("%s=>%T[%s]", name, obj, path)) - } - } - importPath := "" - for _, spec := range targetFile.Imports { - path, _ := strconv.Unquote(spec.Path.Value) - name := filepath.Base(path) - if spec.Name != nil { - name = spec.Name.Name - } - if name == pkgIdentName { - importPath = path - break - } - } - var depSummary string - if importPath != "" { - if dep := graph[importPath]; dep != nil { - depSummary = 
fmt.Sprintf("dep=%s name=%q types=%v typeName=%q errors=%v", importPath, dep.Name, dep.Types != nil, func() string { - if dep.Types == nil { - return "" - } - return dep.Types.Name() - }(), dep.Errors) - } else { - depSummary = "dep_missing=" + importPath - } - } - t.Fatalf("%s selector lost package object for %s; imports=%s; importPath=%q; %s; root errors=%v", fileSuffix, pkgIdentName, strings.Join(importBindings, ", "), importPath, depSummary, rootPkg.Errors) - } - if rootPkg.TypesInfo.ObjectOf(sel.Sel) == nil { - t.Fatalf("%s selector lost object for %s.%s", fileSuffix, pkgIdentName, sel.Sel.Name) - } - return false - }) - if !found { - t.Fatalf("did not find selector using %s in %s", pkgIdentName, fileSuffix) - } - } - checkSelector("inject_repositories.go", "models") - checkSelector("inject_http.go", "http") - if len(rootPkg.Errors) > 0 { - var msgs []string - for _, err := range rootPkg.Errors { - msgs = append(msgs, err.Msg) - } - t.Fatalf("root package has errors under direct importer boundary: %s", strings.Join(msgs, "; ")) - } - for _, p := range []string{"test/internal/models", "test/internal/http"} { - dep := graph[p] - if dep == nil { - t.Fatalf("missing dependency package %s", p) - } - if dep.Types == nil { - t.Fatalf("dependency %s missing types", p) - } - if dep.Name == "" || dep.Types.Name() == "" { - t.Fatalf("dependency %s missing package name", p) - } - if dep.Name != dep.Types.Name() { - t.Fatalf("dependency %s package name mismatch: pkg=%q types=%q", p, dep.Name, dep.Types.Name()) - } - } - _ = fmt.Sprintf -} - func TestLoadTypedPackageGraphCustomExternalArtifactCacheRealAppParity(t *testing.T) { root := os.Getenv("WIRE_REAL_APP_ROOT") if root == "" { diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go index 7bd6c20..23db303 100644 --- a/internal/wire/wire_test.go +++ b/internal/wire/wire_test.go @@ -271,61 +271,6 @@ func TestGenerateRealAppArtifactParity(t *testing.T) { } } -func TestGenerateRealAppSelfOnlyArtifactParity(t 
*testing.T) { - root := os.Getenv("WIRE_REAL_APP_ROOT") - if root == "" { - t.Skip("WIRE_REAL_APP_ROOT not set") - } - artifactDir := t.TempDir() - ctx := context.Background() - - run := func(env []string) ([]GenerateResult, []string) { - t.Helper() - gens, errs := Generate(ctx, root, env, []string{"."}, &GenerateOptions{}) - errStrings := make([]string, len(errs)) - for i, err := range errs { - errStrings[i] = err.Error() - } - sort.Strings(errStrings) - return gens, errStrings - } - - artifactEnv := append(os.Environ(), - "WIRE_LOADER_ARTIFACTS=1", - "WIRE_LOADER_LOCAL_ARTIFACTS=1", - "WIRE_LOADER_ARTIFACT_DIR="+artifactDir, - ) - _, warmErrs := run(artifactEnv) - if len(warmErrs) > 0 { - t.Fatalf("artifact warm-up errors: %v", warmErrs) - } - baseGens, baseErrs := run(artifactEnv) - - selfOnlyEnv := append(append([]string(nil), artifactEnv...), - "WIRE_LOADER_LOCAL_BOUNDARY=self_only", - ) - selfOnlyGens, selfOnlyErrs := run(selfOnlyEnv) - if diff := cmp.Diff(baseErrs, selfOnlyErrs); diff != "" { - t.Fatalf("self_only errors mismatch (-base +self_only):\n%s", diff) - } - if len(baseGens) != len(selfOnlyGens) { - t.Fatalf("generated file count = %d, want %d", len(selfOnlyGens), len(baseGens)) - } - for i := range baseGens { - if baseGens[i].PkgPath != selfOnlyGens[i].PkgPath { - t.Fatalf("generated package[%d] = %q, want %q", i, selfOnlyGens[i].PkgPath, baseGens[i].PkgPath) - } - if diff := cmp.Diff(string(baseGens[i].Content), string(selfOnlyGens[i].Content)); diff != "" { - t.Fatalf("generated content mismatch for %q (-base +self_only):\n%s", baseGens[i].PkgPath, diff) - } - baseGenErrs := comparableGenerateErrors(baseGens[i].Errs) - selfOnlyGenErrs := comparableGenerateErrors(selfOnlyGens[i].Errs) - if diff := cmp.Diff(baseGenErrs, selfOnlyGenErrs); diff != "" { - t.Fatalf("generate errs mismatch for %q (-base +self_only):\n%s", baseGens[i].PkgPath, diff) - } - } -} - func comparableGenerateErrors(errs []error) []string { out := make([]string, len(errs)) for 
i, err := range errs { From 3bc75c443ed89a2855a4c7479d8412da2bccb12f Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 06:06:56 -0500 Subject: [PATCH 11/82] feat: local caching from wire perspective --- internal/loader/custom.go | 48 ++- internal/semanticcache/cache.go | 143 +++++++ internal/wire/parse.go | 563 ++++++++++++++++++++++++++- internal/wire/parse_coverage_test.go | 228 +++++++++++ internal/wire/wire.go | 2 +- 5 files changed, 978 insertions(+), 6 deletions(-) create mode 100644 internal/semanticcache/cache.go diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 49fc217..d4a7f66 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -33,6 +33,8 @@ import ( "golang.org/x/tools/go/gcexportdata" "golang.org/x/tools/go/packages" + + "github.com/goforj/wire/internal/semanticcache" ) type unsupportedError struct { @@ -525,7 +527,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er } l.packages[path] = pkg } - useArtifact := loaderArtifactEnabled(l.env) && !isTarget && !isLocal + useArtifact := loaderArtifactEnabled(l.env) && !isTarget && (!isLocal || l.useLocalSemanticArtifact(meta)) if useArtifact { if typed, ok := l.readArtifact(path, meta, isLocal); ok { linkStart := time.Now() @@ -629,12 +631,23 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er pkg.Types = tpkg pkg.TypesInfo = info pkg.Errors = append(pkg.Errors, typeErrors...) 
- if shouldWriteArtifact(l.env, isTarget, isLocal) && len(pkg.Errors) == 0 { + if shouldWriteArtifact(l.env, isTarget) && len(pkg.Errors) == 0 { _ = l.writeArtifact(meta, tpkg, isLocal) } return pkg, nil } +func (l *customTypedGraphLoader) useLocalSemanticArtifact(meta *packageMeta) bool { + if meta == nil { + return false + } + art, err := semanticcache.Read(l.env, meta.ImportPath, meta.Name, metaFiles(meta)) + if err != nil || art == nil { + return false + } + return art.Supported +} + func (l *customTypedGraphLoader) checkFiles(path string, checker *types.Checker, files []*ast.File) (err error) { defer func() { if r := recover(); r != nil { @@ -650,15 +663,37 @@ func (l *customTypedGraphLoader) readArtifact(path string, meta *packageMeta, is artifactPath, err := loaderArtifactPath(l.env, meta, isLocal) l.stats.artifactPath += time.Since(pathStart) if err != nil { + debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=path_error err=%v", path, isLocal, err) l.stats.artifactRead += time.Since(start) l.stats.artifactMisses++ return nil, false } + if isLocal { + preloadStart := time.Now() + for _, imp := range meta.Imports { + target := imp + if mapped := meta.ImportMap[imp]; mapped != "" { + target = mapped + } + dep, err := l.loadPackage(target) + if err != nil { + debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=preload_dep_error dep=%s err=%v", path, isLocal, target, err) + l.stats.artifactRead += time.Since(start) + l.stats.artifactMisses++ + return nil, false + } + if dep != nil && dep.Types != nil { + l.typesPkgs[target] = dep.Types + } + } + l.stats.artifactImportLink += time.Since(preloadStart) + } var tpkg *types.Package decodeStart := time.Now() tpkg, err = readLoaderArtifact(artifactPath, l.fset, l.typesPkgs, path) l.stats.artifactDecode += time.Since(decodeStart) if err != nil { + debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=decode_error err=%v", path, isLocal, err) l.stats.artifactRead += time.Since(start) 
l.stats.artifactMisses++ return nil, false @@ -673,10 +708,12 @@ func (l *customTypedGraphLoader) writeArtifact(meta *packageMeta, pkg *types.Pac start := time.Now() artifactPath, err := loaderArtifactPath(l.env, meta, isLocal) if err != nil { + debugf(l.ctx, "loader.artifact.write_skip pkg=%s local=%t reason=path_error err=%v", meta.ImportPath, isLocal, err) l.stats.artifactWrite += time.Since(start) return err } if artifactUpToDate(l.env, artifactPath, meta, isLocal) { + debugf(l.ctx, "loader.artifact.write_skip pkg=%s local=%t reason=up_to_date", meta.ImportPath, isLocal) l.stats.artifactWrite += time.Since(start) return nil } @@ -684,6 +721,9 @@ func (l *customTypedGraphLoader) writeArtifact(meta *packageMeta, pkg *types.Pac l.stats.artifactWrite += time.Since(start) if writeErr == nil { l.stats.artifactWrites++ + debugf(l.ctx, "loader.artifact.write_ok pkg=%s local=%t path=%s", meta.ImportPath, isLocal, artifactPath) + } else { + debugf(l.ctx, "loader.artifact.write_fail pkg=%s local=%t err=%v", meta.ImportPath, isLocal, writeErr) } if writeErr != nil { return writeErr @@ -691,8 +731,8 @@ func (l *customTypedGraphLoader) writeArtifact(meta *packageMeta, pkg *types.Pac return nil } -func shouldWriteArtifact(env []string, isTarget, isLocal bool) bool { - if !loaderArtifactEnabled(env) || isTarget || isLocal { +func shouldWriteArtifact(env []string, isTarget bool) bool { + if !loaderArtifactEnabled(env) || isTarget { return false } return true diff --git a/internal/semanticcache/cache.go b/internal/semanticcache/cache.go new file mode 100644 index 0000000..4442415 --- /dev/null +++ b/internal/semanticcache/cache.go @@ -0,0 +1,143 @@ +package semanticcache + +import ( + "crypto/sha256" + "encoding/gob" + "encoding/hex" + "os" + "path/filepath" + "runtime" + "strconv" +) + +const dirEnv = "WIRE_SEMANTIC_CACHE_DIR" + +type PackageArtifact struct { + Version int + PackagePath string + PackageName string + HasProviderSetVars bool + Supported bool + Vars 
map[string]ProviderSetArtifact +} + +type ProviderSetArtifact struct { + Items []ProviderSetItemArtifact +} + +type ProviderSetItemArtifact struct { + Kind string + ImportPath string + Name string + Type TypeRef + Type2 TypeRef + FieldNames []string + AllFields bool +} + +type TypeRef struct { + ImportPath string + Name string + Pointer int +} + +func ArtifactPath(env []string, importPath, packageName string, files []string) (string, error) { + dir, err := artifactDir(env) + if err != nil { + return "", err + } + key, err := artifactKey(importPath, packageName, files) + if err != nil { + return "", err + } + return filepath.Join(dir, key+".gob"), nil +} + +func Read(env []string, importPath, packageName string, files []string) (*PackageArtifact, error) { + path, err := ArtifactPath(env, importPath, packageName, files) + if err != nil { + return nil, err + } + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + var art PackageArtifact + if err := gob.NewDecoder(f).Decode(&art); err != nil { + return nil, err + } + return &art, nil +} + +func Write(env []string, importPath, packageName string, files []string, art *PackageArtifact) error { + path, err := ArtifactPath(env, importPath, packageName, files) + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + return gob.NewEncoder(f).Encode(art) +} + +func Exists(env []string, importPath, packageName string, files []string) bool { + path, err := ArtifactPath(env, importPath, packageName, files) + if err != nil { + return false + } + _, err = os.Stat(path) + return err == nil +} + +func artifactDir(env []string) (string, error) { + for i := len(env) - 1; i >= 0; i-- { + key, val, ok := splitEnv(env[i]) + if ok && key == dirEnv && val != "" { + return val, nil + } + } + base, err := os.UserCacheDir() + if err != nil { + return "", err + } + return 
filepath.Join(base, "wire", "semantic-artifacts"), nil +} + +func artifactKey(importPath, packageName string, files []string) (string, error) { + sum := sha256.New() + sum.Write([]byte("wire-semantic-artifact-v1\n")) + sum.Write([]byte(runtime.Version())) + sum.Write([]byte{'\n'}) + sum.Write([]byte(importPath)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(packageName)) + sum.Write([]byte{'\n'}) + for _, name := range files { + info, err := os.Stat(name) + if err != nil { + return "", err + } + sum.Write([]byte(name)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.Size(), 10))) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.ModTime().UnixNano(), 10))) + sum.Write([]byte{'\n'}) + } + return hex.EncodeToString(sum.Sum(nil)), nil +} + +func splitEnv(kv string) (string, string, bool) { + for i := 0; i < len(kv); i++ { + if kv[i] == '=' { + return kv[:i], kv[i+1:], true + } + } + return "", "", false +} diff --git a/internal/wire/parse.go b/internal/wire/parse.go index e6f8cb1..a825b4b 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -33,6 +33,7 @@ import ( "golang.org/x/tools/go/types/typeutil" "github.com/goforj/wire/internal/loader" + "github.com/goforj/wire/internal/semanticcache" ) // A providerSetSrc captures the source for a type provided by a ProviderSet. @@ -267,7 +268,7 @@ func Load(ctx context.Context, wd string, env []string, tags string, patterns [] Fset: fset, Sets: make(map[ProviderSetID]*ProviderSet), } - oc := newObjectCache(pkgs) + oc := newObjectCacheWithEnv(pkgs, env) ec := new(errorCollector) for _, pkg := range pkgs { if isWireImport(pkg.PkgPath) { @@ -452,8 +453,10 @@ func (in *Injector) String() string { // objectCache is a lazily evaluated mapping of objects to Wire structures. 
type objectCache struct { fset *token.FileSet + env []string packages map[string]*packages.Package objects map[objRef]objCacheEntry + semantic map[string]*semanticcache.PackageArtifact hasher typeutil.Hasher } @@ -468,13 +471,19 @@ type objCacheEntry struct { } func newObjectCache(pkgs []*packages.Package) *objectCache { + return newObjectCacheWithEnv(pkgs, nil) +} + +func newObjectCacheWithEnv(pkgs []*packages.Package, env []string) *objectCache { if len(pkgs) == 0 { panic("object cache must have packages to draw from") } oc := &objectCache{ fset: pkgs[0].Fset, + env: append([]string(nil), env...), packages: make(map[string]*packages.Package), objects: make(map[objRef]objCacheEntry), + semantic: make(map[string]*semanticcache.PackageArtifact), hasher: typeutil.MakeHasher(), } // Depth-first search of all dependencies to gather import path to @@ -482,6 +491,7 @@ func newObjectCache(pkgs []*packages.Package) *objectCache { // call to packages.Load and an import path X, there will exist only // one *packages.Package value with PkgPath X. 
oc.registerPackages(pkgs, false) + oc.recordSemanticArtifacts() return oc } @@ -528,6 +538,11 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { switch obj := obj.(type) { case *types.Var: spec := oc.varDecl(obj) + if spec == nil && isProviderSetType(obj.Type()) { + if pset, ok, errs := oc.semanticProviderSet(obj); ok { + return pset, errs + } + } if spec == nil || len(spec.Values) == 0 { return nil, []error{fmt.Errorf("%v is not a provider or a provider set", obj)} } @@ -546,6 +561,552 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { } } +func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, []error) { + pkg := oc.packages[obj.Pkg().Path()] + if pkg == nil { + return nil, false, nil + } + art := oc.semanticArtifact(pkg) + if art == nil || !art.Supported { + return nil, false, nil + } + setArt, ok := art.Vars[obj.Name()] + if !ok { + return nil, false, nil + } + pset := &ProviderSet{ + Pos: obj.Pos(), + PkgPath: obj.Pkg().Path(), + VarName: obj.Name(), + } + ec := new(errorCollector) + for _, item := range setArt.Items { + switch item.Kind { + case "func": + providerObj, errs := oc.semanticProvider(item.ImportPath, item.Name) + if len(errs) > 0 { + ec.add(errs...) + continue + } + pset.Providers = append(pset.Providers, providerObj) + case "set": + setObj, errs := oc.semanticImportedSet(item.ImportPath, item.Name) + if len(errs) > 0 { + ec.add(errs...) + continue + } + pset.Imports = append(pset.Imports, setObj) + case "bind": + binding, errs := oc.semanticBinding(item) + if len(errs) > 0 { + ec.add(errs...) + continue + } + pset.Bindings = append(pset.Bindings, binding) + case "struct": + providerObj, errs := oc.semanticStructProvider(item) + if len(errs) > 0 { + ec.add(errs...) + continue + } + pset.Providers = append(pset.Providers, providerObj) + case "fields": + fields, errs := oc.semanticFields(item) + if len(errs) > 0 { + ec.add(errs...) 
+ continue + } + pset.Fields = append(pset.Fields, fields...) + default: + ec.add(fmt.Errorf("unsupported semantic cache item kind %q", item.Kind)) + } + } + if len(ec.errors) > 0 { + return nil, true, ec.errors + } + var errs []error + pset.providerMap, pset.srcMap, errs = buildProviderMap(oc.fset, oc.hasher, pset) + if len(errs) > 0 { + return nil, true, errs + } + if errs := verifyAcyclic(pset.providerMap, oc.hasher); len(errs) > 0 { + return nil, true, errs + } + return pset, true, nil +} + +func (oc *objectCache) semanticProvider(importPath, name string) (*Provider, []error) { + pkg := oc.packages[importPath] + if pkg == nil || pkg.Types == nil { + return nil, []error{fmt.Errorf("missing typed package for %s", importPath)} + } + obj := pkg.Types.Scope().Lookup(name) + fn, ok := obj.(*types.Func) + if !ok || fn == nil { + return nil, []error{fmt.Errorf("%s.%s is not a provider function", importPath, name)} + } + return processFuncProvider(oc.fset, fn) +} + +func (oc *objectCache) semanticImportedSet(importPath, name string) (*ProviderSet, []error) { + pkg := oc.packages[importPath] + if pkg == nil || pkg.Types == nil { + return nil, []error{fmt.Errorf("missing typed package for %s", importPath)} + } + obj := pkg.Types.Scope().Lookup(name) + v, ok := obj.(*types.Var) + if !ok || v == nil || !isProviderSetType(v.Type()) { + return nil, []error{fmt.Errorf("%s.%s is not a provider set", importPath, name)} + } + item, errs := oc.get(v) + if len(errs) > 0 { + return nil, errs + } + pset, ok := item.(*ProviderSet) + if !ok || pset == nil { + return nil, []error{fmt.Errorf("%s.%s did not resolve to a provider set", importPath, name)} + } + return pset, nil +} + +func (oc *objectCache) semanticBinding(item semanticcache.ProviderSetItemArtifact) (*IfaceBinding, []error) { + iface, err := oc.semanticType(item.Type) + if err != nil { + return nil, []error{err} + } + provided, err := oc.semanticType(item.Type2) + if err != nil { + return nil, []error{err} + } + return 
&IfaceBinding{ + Iface: iface, + Provided: provided, + }, nil +} + +func (oc *objectCache) semanticStructProvider(item semanticcache.ProviderSetItemArtifact) (*Provider, []error) { + typeName, err := oc.semanticTypeName(item.Type) + if err != nil { + return nil, []error{err} + } + out := typeName.Type() + st, ok := out.Underlying().(*types.Struct) + if !ok { + return nil, []error{fmt.Errorf("%s.%s does not name a struct", item.Type.ImportPath, item.Type.Name)} + } + provider := &Provider{ + Pkg: typeName.Pkg(), + Name: typeName.Name(), + Pos: typeName.Pos(), + IsStruct: true, + Out: []types.Type{out, types.NewPointer(out)}, + } + if item.AllFields { + for i := 0; i < st.NumFields(); i++ { + if isPrevented(st.Tag(i)) { + continue + } + f := st.Field(i) + provider.Args = append(provider.Args, ProviderInput{ + Type: f.Type(), + FieldName: f.Name(), + }) + } + } else { + for _, fieldName := range item.FieldNames { + f := lookupStructField(st, fieldName) + if f == nil { + return nil, []error{fmt.Errorf("field %q not found in %s.%s", fieldName, item.Type.ImportPath, item.Type.Name)} + } + provider.Args = append(provider.Args, ProviderInput{ + Type: f.Type(), + FieldName: f.Name(), + }) + } + } + return provider, nil +} + +func (oc *objectCache) semanticFields(item semanticcache.ProviderSetItemArtifact) ([]*Field, []error) { + parent, err := oc.semanticType(item.Type) + if err != nil { + return nil, []error{err} + } + structType, ptrToField, err := structFromFieldsParent(parent) + if err != nil { + return nil, []error{err} + } + fields := make([]*Field, 0, len(item.FieldNames)) + for _, fieldName := range item.FieldNames { + v := lookupStructField(structType, fieldName) + if v == nil { + return nil, []error{fmt.Errorf("field %q not found", fieldName)} + } + out := []types.Type{v.Type()} + if ptrToField { + out = append(out, types.NewPointer(v.Type())) + } + fields = append(fields, &Field{ + Parent: parent, + Name: v.Name(), + Pkg: v.Pkg(), + Pos: v.Pos(), + Out: out, + }) 
+ } + return fields, nil +} + +func (oc *objectCache) semanticType(ref semanticcache.TypeRef) (types.Type, error) { + typeName, err := oc.semanticTypeName(ref) + if err != nil { + return nil, err + } + var typ types.Type = typeName.Type() + for i := 0; i < ref.Pointer; i++ { + typ = types.NewPointer(typ) + } + return typ, nil +} + +func (oc *objectCache) semanticTypeName(ref semanticcache.TypeRef) (*types.TypeName, error) { + pkg := oc.packages[ref.ImportPath] + if pkg == nil || pkg.Types == nil { + return nil, fmt.Errorf("missing typed package for %s", ref.ImportPath) + } + obj := pkg.Types.Scope().Lookup(ref.Name) + typeName, ok := obj.(*types.TypeName) + if !ok || typeName == nil { + return nil, fmt.Errorf("%s.%s is not a named type", ref.ImportPath, ref.Name) + } + return typeName, nil +} + +func structFromFieldsParent(parent types.Type) (*types.Struct, bool, error) { + ptr, ok := parent.(*types.Pointer) + if !ok { + return nil, false, fmt.Errorf("parent type %s is not a pointer", types.TypeString(parent, nil)) + } + switch t := ptr.Elem().Underlying().(type) { + case *types.Pointer: + st, ok := t.Elem().Underlying().(*types.Struct) + if !ok { + return nil, false, fmt.Errorf("parent type %s does not point to a struct", types.TypeString(parent, nil)) + } + return st, true, nil + case *types.Struct: + return t, false, nil + default: + return nil, false, fmt.Errorf("parent type %s does not point to a struct", types.TypeString(parent, nil)) + } +} + +func lookupStructField(st *types.Struct, name string) *types.Var { + for i := 0; i < st.NumFields(); i++ { + if st.Field(i).Name() == name { + return st.Field(i) + } + } + return nil +} + +func (oc *objectCache) semanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { + if pkg == nil { + return nil + } + if art, ok := oc.semantic[pkg.PkgPath]; ok { + return art + } + if len(oc.env) == 0 || len(pkg.GoFiles) == 0 { + return nil + } + art, err := semanticcache.Read(oc.env, pkg.PkgPath, pkg.Name, 
pkg.GoFiles) + if err != nil { + return nil + } + oc.semantic[pkg.PkgPath] = art + return art +} + +func (oc *objectCache) recordSemanticArtifacts() { + if len(oc.env) == 0 { + return + } + for _, pkg := range oc.packages { + if pkg == nil || len(pkg.Syntax) == 0 || pkg.Types == nil || pkg.TypesInfo == nil || len(pkg.GoFiles) == 0 { + continue + } + art := buildSemanticArtifact(pkg) + if art == nil { + continue + } + oc.semantic[pkg.PkgPath] = art + _ = semanticcache.Write(oc.env, pkg.PkgPath, pkg.Name, pkg.GoFiles, art) + } +} + +func buildSemanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { + if pkg == nil || pkg.Types == nil || pkg.TypesInfo == nil { + return nil + } + art := &semanticcache.PackageArtifact{ + Version: 1, + PackagePath: pkg.PkgPath, + PackageName: pkg.Name, + Supported: true, + Vars: make(map[string]semanticcache.ProviderSetArtifact), + } + scope := pkg.Types.Scope() + for _, name := range scope.Names() { + obj := scope.Lookup(name) + v, ok := obj.(*types.Var) + if !ok || !isProviderSetType(v.Type()) { + continue + } + art.HasProviderSetVars = true + spec := semanticVarDecl(pkg, v) + if spec == nil || len(spec.Values) == 0 { + art.Supported = false + continue + } + var idx int + found := false + for i := range spec.Names { + if spec.Names[i].Name == v.Name() { + idx = i + found = true + break + } + } + if !found || idx >= len(spec.Values) { + art.Supported = false + continue + } + setArt, ok := summarizeSemanticProviderSet(pkg.TypesInfo, spec.Values[idx], pkg.PkgPath) + if !ok { + art.Supported = false + continue + } + art.Vars[v.Name()] = setArt + } + return art +} + +func summarizeSemanticProviderSet(info *types.Info, expr ast.Expr, pkgPath string) (semanticcache.ProviderSetArtifact, bool) { + call, ok := astutil.Unparen(expr).(*ast.CallExpr) + if !ok { + return semanticcache.ProviderSetArtifact{}, false + } + fnObj := qualifiedIdentObject(info, call.Fun) + if fnObj == nil || fnObj.Pkg() == nil || 
!isWireImport(fnObj.Pkg().Path()) || fnObj.Name() != "NewSet" { + return semanticcache.ProviderSetArtifact{}, false + } + setArt := semanticcache.ProviderSetArtifact{ + Items: make([]semanticcache.ProviderSetItemArtifact, 0, len(call.Args)), + } + for _, arg := range call.Args { + items, ok := summarizeSemanticProviderSetArg(info, astutil.Unparen(arg), pkgPath) + if !ok { + return semanticcache.ProviderSetArtifact{}, false + } + setArt.Items = append(setArt.Items, items...) + } + return setArt, true +} + +func summarizeSemanticProviderSetArg(info *types.Info, expr ast.Expr, pkgPath string) ([]semanticcache.ProviderSetItemArtifact, bool) { + if obj := qualifiedIdentObject(info, expr); obj != nil && obj.Pkg() != nil && obj.Exported() { + item := semanticcache.ProviderSetItemArtifact{ + ImportPath: obj.Pkg().Path(), + Name: obj.Name(), + } + switch typed := obj.(type) { + case *types.Func: + item.Kind = "func" + case *types.Var: + if !isProviderSetType(typed.Type()) { + return nil, false + } + item.Kind = "set" + default: + return nil, false + } + if item.ImportPath == "" { + item.ImportPath = pkgPath + } + return []semanticcache.ProviderSetItemArtifact{item}, true + } + call, ok := expr.(*ast.CallExpr) + if !ok { + return nil, false + } + fnObj := qualifiedIdentObject(info, call.Fun) + if fnObj == nil || fnObj.Pkg() == nil || !isWireImport(fnObj.Pkg().Path()) { + return nil, false + } + switch fnObj.Name() { + case "NewSet": + nested, ok := summarizeSemanticProviderSet(info, call, pkgPath) + if !ok { + return nil, false + } + return nested.Items, true + case "Bind": + item, ok := summarizeSemanticBind(info, call) + if !ok { + return nil, false + } + return []semanticcache.ProviderSetItemArtifact{item}, true + case "Struct": + item, ok := summarizeSemanticStruct(info, call) + if !ok { + return nil, false + } + return []semanticcache.ProviderSetItemArtifact{item}, true + case "FieldsOf": + item, ok := summarizeSemanticFields(info, call) + if !ok { + return nil, false + 
} + return []semanticcache.ProviderSetItemArtifact{item}, true + default: + return nil, false + } +} + +func summarizeSemanticBind(info *types.Info, call *ast.CallExpr) (semanticcache.ProviderSetItemArtifact, bool) { + if len(call.Args) != 2 { + return semanticcache.ProviderSetItemArtifact{}, false + } + iface, ok := summarizeTypeRef(info.TypeOf(call.Args[0])) + if !ok || iface.Pointer == 0 { + return semanticcache.ProviderSetItemArtifact{}, false + } + iface.Pointer-- + providedType := info.TypeOf(call.Args[1]) + if bindShouldUsePointer(info, call) { + ptr, ok := providedType.(*types.Pointer) + if !ok { + return semanticcache.ProviderSetItemArtifact{}, false + } + providedType = ptr.Elem() + } + provided, ok := summarizeTypeRef(providedType) + if !ok { + return semanticcache.ProviderSetItemArtifact{}, false + } + return semanticcache.ProviderSetItemArtifact{ + Kind: "bind", + Type: iface, + Type2: provided, + }, true +} + +func summarizeSemanticStruct(info *types.Info, call *ast.CallExpr) (semanticcache.ProviderSetItemArtifact, bool) { + if len(call.Args) < 1 { + return semanticcache.ProviderSetItemArtifact{}, false + } + structType := info.TypeOf(call.Args[0]) + ptr, ok := structType.(*types.Pointer) + if !ok { + return semanticcache.ProviderSetItemArtifact{}, false + } + ref, ok := summarizeTypeRef(ptr.Elem()) + if !ok { + return semanticcache.ProviderSetItemArtifact{}, false + } + item := semanticcache.ProviderSetItemArtifact{ + Kind: "struct", + Type: ref, + } + if allFields(call) { + item.AllFields = true + return item, true + } + item.FieldNames = make([]string, 0, len(call.Args)-1) + for i := 1; i < len(call.Args); i++ { + lit, ok := call.Args[i].(*ast.BasicLit) + if !ok { + return semanticcache.ProviderSetItemArtifact{}, false + } + fieldName, err := strconv.Unquote(lit.Value) + if err != nil { + return semanticcache.ProviderSetItemArtifact{}, false + } + item.FieldNames = append(item.FieldNames, fieldName) + } + return item, true +} + +func 
summarizeSemanticFields(info *types.Info, call *ast.CallExpr) (semanticcache.ProviderSetItemArtifact, bool) { + if len(call.Args) < 2 { + return semanticcache.ProviderSetItemArtifact{}, false + } + parent, ok := summarizeTypeRef(info.TypeOf(call.Args[0])) + if !ok { + return semanticcache.ProviderSetItemArtifact{}, false + } + item := semanticcache.ProviderSetItemArtifact{ + Kind: "fields", + Type: parent, + FieldNames: make([]string, 0, len(call.Args)-1), + } + for i := 1; i < len(call.Args); i++ { + lit, ok := call.Args[i].(*ast.BasicLit) + if !ok { + return semanticcache.ProviderSetItemArtifact{}, false + } + fieldName, err := strconv.Unquote(lit.Value) + if err != nil { + return semanticcache.ProviderSetItemArtifact{}, false + } + item.FieldNames = append(item.FieldNames, fieldName) + } + return item, true +} + +func summarizeTypeRef(typ types.Type) (semanticcache.TypeRef, bool) { + ref := semanticcache.TypeRef{} + for { + ptr, ok := typ.(*types.Pointer) + if !ok { + break + } + ref.Pointer++ + typ = ptr.Elem() + } + named, ok := typ.(*types.Named) + if !ok { + return semanticcache.TypeRef{}, false + } + obj := named.Obj() + if obj == nil || obj.Pkg() == nil { + return semanticcache.TypeRef{}, false + } + ref.ImportPath = obj.Pkg().Path() + ref.Name = obj.Name() + return ref, true +} + +func semanticVarDecl(pkg *packages.Package, obj *types.Var) *ast.ValueSpec { + pos := obj.Pos() + for _, f := range pkg.Syntax { + tokenFile := pkg.Fset.File(f.Pos()) + if tokenFile == nil { + continue + } + if base := tokenFile.Base(); base <= int(pos) && int(pos) < base+tokenFile.Size() { + path, _ := astutil.PathEnclosingInterval(f, pos, pos) + for _, node := range path { + if spec, ok := node.(*ast.ValueSpec); ok { + return spec + } + } + } + } + return nil +} + // varDecl finds the declaration that defines the given variable. func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec { // TODO(light): Walk files to build object -> declaration mapping, if more performant. 
diff --git a/internal/wire/parse_coverage_test.go b/internal/wire/parse_coverage_test.go index 7c7a3b7..3a23d18 100644 --- a/internal/wire/parse_coverage_test.go +++ b/internal/wire/parse_coverage_test.go @@ -22,6 +22,8 @@ import ( "golang.org/x/tools/go/packages" "golang.org/x/tools/go/types/typeutil" + + "github.com/goforj/wire/internal/semanticcache" ) func TestFindInjectorBuildVariants(t *testing.T) { @@ -220,6 +222,232 @@ func TestProcessStructProviderDuplicateFields(t *testing.T) { } } +func TestSummarizeSemanticProviderSet(t *testing.T) { + t.Parallel() + + info := &types.Info{ + Uses: make(map[*ast.Ident]types.Object), + } + wirePkg := types.NewPackage("github.com/goforj/wire", "wire") + wireIdent := ast.NewIdent("wire") + newSetIdent := ast.NewIdent("NewSet") + info.Uses[wireIdent] = types.NewPkgName(token.NoPos, nil, "wire", wirePkg) + info.Uses[newSetIdent] = types.NewFunc(token.NoPos, wirePkg, "NewSet", nil) + + depPkg := types.NewPackage("example.com/dep", "dep") + fnIdent := ast.NewIdent("NewMessage") + info.Uses[fnIdent] = types.NewFunc(token.NoPos, depPkg, "NewMessage", types.NewSignatureType(nil, nil, nil, nil, types.NewTuple(types.NewVar(token.NoPos, depPkg, "", types.Typ[types.String])), false)) + + call := &ast.CallExpr{ + Fun: &ast.SelectorExpr{X: wireIdent, Sel: newSetIdent}, + Args: []ast.Expr{ + fnIdent, + }, + } + got, ok := summarizeSemanticProviderSet(info, call, "example.com/app") + if !ok { + t.Fatal("summarizeSemanticProviderSet() = unsupported, want supported") + } + if len(got.Items) != 1 { + t.Fatalf("items len = %d, want 1", len(got.Items)) + } + if got.Items[0].Kind != "func" || got.Items[0].ImportPath != "example.com/dep" || got.Items[0].Name != "NewMessage" { + t.Fatalf("unexpected item: %+v", got.Items[0]) + } +} + +func TestSummarizeSemanticProviderSetTypeOnlyForms(t *testing.T) { + t.Parallel() + + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Uses: make(map[*ast.Ident]types.Object), + } + wirePkg := 
types.NewPackage("github.com/goforj/wire", "wire") + wireIdent := ast.NewIdent("wire") + info.Uses[wireIdent] = types.NewPkgName(token.NoPos, nil, "wire", wirePkg) + + appPkg := types.NewPackage("example.com/app", "app") + fooObj := types.NewTypeName(token.NoPos, appPkg, "Foo", nil) + fooNamed := types.NewNamed(fooObj, types.NewStruct([]*types.Var{ + types.NewVar(token.NoPos, appPkg, "Message", types.Typ[types.String]), + }, []string{""}), nil) + fooIfaceObj := types.NewTypeName(token.NoPos, appPkg, "Fooer", nil) + fooIface := types.NewNamed(fooIfaceObj, types.NewInterfaceType(nil, nil).Complete(), nil) + + newSetIdent := ast.NewIdent("NewSet") + bindIdent := ast.NewIdent("Bind") + structIdent := ast.NewIdent("Struct") + fieldsIdent := ast.NewIdent("FieldsOf") + info.Uses[newSetIdent] = types.NewFunc(token.NoPos, wirePkg, "NewSet", nil) + info.Uses[bindIdent] = types.NewFunc(token.NoPos, wirePkg, "Bind", nil) + info.Uses[structIdent] = types.NewFunc(token.NoPos, wirePkg, "Struct", nil) + info.Uses[fieldsIdent] = types.NewFunc(token.NoPos, wirePkg, "FieldsOf", nil) + + newFooCall := &ast.CallExpr{Fun: ast.NewIdent("new"), Args: []ast.Expr{ast.NewIdent("Foo")}} + info.Types[newFooCall] = types.TypeAndValue{Type: types.NewPointer(fooNamed)} + newFooIfaceCall := &ast.CallExpr{Fun: ast.NewIdent("new"), Args: []ast.Expr{ast.NewIdent("Fooer")}} + info.Types[newFooIfaceCall] = types.TypeAndValue{Type: types.NewPointer(fooIface)} + ptrToPtrFooCall := &ast.CallExpr{Fun: ast.NewIdent("new"), Args: []ast.Expr{ast.NewIdent("FooPtr")}} + info.Types[ptrToPtrFooCall] = types.TypeAndValue{Type: types.NewPointer(types.NewPointer(fooNamed))} + + call := &ast.CallExpr{ + Fun: &ast.SelectorExpr{X: wireIdent, Sel: newSetIdent}, + Args: []ast.Expr{ + &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: bindIdent}, Args: []ast.Expr{newFooIfaceCall, newFooCall}}, + &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: structIdent}, Args: []ast.Expr{newFooCall, &ast.BasicLit{Kind: 
token.STRING, Value: "\"*\""}}}, + &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: fieldsIdent}, Args: []ast.Expr{ptrToPtrFooCall, &ast.BasicLit{Kind: token.STRING, Value: "\"Message\""}}}, + }, + } + got, ok := summarizeSemanticProviderSet(info, call, "example.com/app") + if !ok { + t.Fatal("summarizeSemanticProviderSet() = unsupported, want supported") + } + if len(got.Items) != 3 { + t.Fatalf("items len = %d, want 3", len(got.Items)) + } + if got.Items[0].Kind != "bind" || got.Items[1].Kind != "struct" || got.Items[2].Kind != "fields" { + t.Fatalf("unexpected kinds: %+v", got.Items) + } +} + +func TestObjectCacheSemanticProviderSetFallback(t *testing.T) { + t.Parallel() + + fset := token.NewFileSet() + wirePkg := types.NewPackage("github.com/goforj/wire", "wire") + wireObj := types.NewTypeName(token.NoPos, wirePkg, "ProviderSet", nil) + wireNamed := types.NewNamed(wireObj, types.NewStruct(nil, nil), nil) + + depTypes := types.NewPackage("example.com/dep", "dep") + msgFnSig := types.NewSignatureType(nil, nil, nil, nil, types.NewTuple(types.NewVar(token.NoPos, depTypes, "", types.Typ[types.String])), false) + msgFn := types.NewFunc(token.NoPos, depTypes, "NewMessage", msgFnSig) + setVar := types.NewVar(token.NoPos, depTypes, "Set", wireNamed) + depTypes.Scope().Insert(msgFn) + depTypes.Scope().Insert(setVar) + + depPkg := &packages.Package{ + Name: "dep", + PkgPath: depTypes.Path(), + Types: depTypes, + Fset: fset, + Imports: make(map[string]*packages.Package), + } + oc := &objectCache{ + fset: fset, + packages: map[string]*packages.Package{depPkg.PkgPath: depPkg}, + objects: make(map[objRef]objCacheEntry), + semantic: map[string]*semanticcache.PackageArtifact{ + depPkg.PkgPath: { + Version: 1, + PackagePath: depPkg.PkgPath, + PackageName: depPkg.Name, + Supported: true, + Vars: map[string]semanticcache.ProviderSetArtifact{ + "Set": { + Items: []semanticcache.ProviderSetItemArtifact{ + {Kind: "func", ImportPath: depPkg.PkgPath, Name: "NewMessage"}, + }, + 
}, + }, + }, + }, + hasher: typeutil.MakeHasher(), + } + item, errs := oc.get(setVar) + if len(errs) > 0 { + t.Fatalf("oc.get(Set) errs = %v", errs) + } + pset, ok := item.(*ProviderSet) + if !ok || pset == nil { + t.Fatalf("oc.get(Set) type = %T, want *ProviderSet", item) + } + if len(pset.Providers) != 1 || pset.Providers[0].Name != "NewMessage" { + t.Fatalf("unexpected providers: %+v", pset.Providers) + } +} + +func TestObjectCacheSemanticProviderSetFallbackTypeOnlyForms(t *testing.T) { + t.Parallel() + + fset := token.NewFileSet() + wirePkg := types.NewPackage("github.com/goforj/wire", "wire") + wireObj := types.NewTypeName(token.NoPos, wirePkg, "ProviderSet", nil) + wireNamed := types.NewNamed(wireObj, types.NewStruct(nil, nil), nil) + + appTypes := types.NewPackage("example.com/app", "app") + fooIfaceObj := types.NewTypeName(token.NoPos, appTypes, "Fooer", nil) + _ = types.NewNamed(fooIfaceObj, types.NewInterfaceType(nil, nil).Complete(), nil) + fooObj := types.NewTypeName(token.NoPos, appTypes, "Foo", nil) + _ = types.NewNamed(fooObj, types.NewStruct([]*types.Var{ + types.NewVar(token.NoPos, appTypes, "Message", types.Typ[types.String]), + }, []string{""}), nil) + setVar := types.NewVar(token.NoPos, appTypes, "Set", wireNamed) + appTypes.Scope().Insert(fooIfaceObj) + appTypes.Scope().Insert(fooObj) + appTypes.Scope().Insert(setVar) + + appPkg := &packages.Package{ + Name: "app", + PkgPath: appTypes.Path(), + Types: appTypes, + Fset: fset, + Imports: make(map[string]*packages.Package), + } + oc := &objectCache{ + fset: fset, + packages: map[string]*packages.Package{appPkg.PkgPath: appPkg}, + objects: make(map[objRef]objCacheEntry), + semantic: map[string]*semanticcache.PackageArtifact{ + appPkg.PkgPath: { + Version: 1, + PackagePath: appPkg.PkgPath, + PackageName: appPkg.Name, + Supported: true, + Vars: map[string]semanticcache.ProviderSetArtifact{ + "Set": { + Items: []semanticcache.ProviderSetItemArtifact{ + { + Kind: "bind", + Type: 
semanticcache.TypeRef{ImportPath: appPkg.PkgPath, Name: "Fooer"}, + Type2: semanticcache.TypeRef{ImportPath: appPkg.PkgPath, Name: "Foo"}, + }, + { + Kind: "struct", + Type: semanticcache.TypeRef{ImportPath: appPkg.PkgPath, Name: "Foo"}, + AllFields: true, + }, + { + Kind: "fields", + Type: semanticcache.TypeRef{ImportPath: appPkg.PkgPath, Name: "Foo", Pointer: 2}, + FieldNames: []string{"Message"}, + }, + }, + }, + }, + }, + }, + hasher: typeutil.MakeHasher(), + } + item, errs := oc.get(setVar) + if len(errs) > 0 { + t.Fatalf("oc.get(Set) errs = %v", errs) + } + pset, ok := item.(*ProviderSet) + if !ok || pset == nil { + t.Fatalf("oc.get(Set) type = %T, want *ProviderSet", item) + } + if len(pset.Bindings) != 1 { + t.Fatalf("bindings len = %d, want 1", len(pset.Bindings)) + } + if len(pset.Providers) != 1 || !pset.Providers[0].IsStruct { + t.Fatalf("providers = %+v, want one struct provider", pset.Providers) + } + if len(pset.Fields) != 1 || pset.Fields[0].Name != "Message" { + t.Fatalf("fields = %+v, want Message field", pset.Fields) + } +} + func TestProcessFuncProviderErrors(t *testing.T) { t.Parallel() diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 09bf814..99062ac 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -121,7 +121,7 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o } generated[i].OutputPath = filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") g := newGen(pkg) - oc := newObjectCache([]*packages.Package{pkg}) + oc := newObjectCacheWithEnv([]*packages.Package{pkg}, env) injectorStart := time.Now() injectorFiles, genErrs := generateInjectors(oc, g, pkg) logTiming(ctx, "generate.package."+pkg.PkgPath+".injectors", injectorStart) From a8c2a0212bd481a18ba25525800407c903af9679 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 06:40:37 -0500 Subject: [PATCH 12/82] feat: go dep cache --- internal/loader/custom.go | 169 ++++++++++-- internal/loader/discovery.go | 4 + 
internal/loader/discovery_cache.go | 331 ++++++++++++++++++++++++ internal/loader/discovery_cache_test.go | 126 +++++++++ internal/wire/profile_bench_test.go | 32 +++ 5 files changed, 647 insertions(+), 15 deletions(-) create mode 100644 internal/loader/discovery_cache.go create mode 100644 internal/loader/discovery_cache_test.go create mode 100644 internal/wire/profile_bench_test.go diff --git a/internal/loader/custom.go b/internal/loader/custom.go index d4a7f66..c938acc 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -29,6 +29,7 @@ import ( "runtime/pprof" "sort" "strings" + "sync" "time" "golang.org/x/tools/go/gcexportdata" @@ -82,9 +83,18 @@ type customTypedGraphLoader struct { importer types.Importer loading map[string]bool isLocalCache map[string]bool + localSemanticOK map[string]bool + artifactPrefetch map[string]artifactPrefetchEntry stats typedLoadStats } +type artifactPrefetchEntry struct { + path string + data []byte + err error + ok bool +} + type typedLoadStats struct { read time.Duration parse time.Duration @@ -106,6 +116,9 @@ type typedLoadStats struct { artifactDecode time.Duration artifactImportLink time.Duration artifactWrite time.Duration + artifactPrefetch time.Duration + rootLoad time.Duration + discovery time.Duration artifactHits int artifactMisses int artifactWrites int @@ -225,6 +238,7 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz meta map[string]*packageMeta err error ) + discoveryStart := time.Now() if req.Discovery != nil && len(req.Discovery.meta) > 0 { meta = req.Discovery.meta } else { @@ -239,6 +253,7 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz return nil, err } } + discoveryDuration := time.Since(discoveryStart) if len(meta) == 0 { return nil, unsupportedError{reason: "empty go list result"} } @@ -259,12 +274,19 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz importer: 
importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), loading: make(map[string]bool, len(meta)), isLocalCache: make(map[string]bool, len(meta)), - stats: typedLoadStats{}, - } + localSemanticOK: make(map[string]bool, len(meta)), + artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), + stats: typedLoadStats{discovery: discoveryDuration}, + } + prefetchStart := time.Now() + l.prefetchArtifacts() + l.stats.artifactPrefetch = time.Since(prefetchStart) + rootLoadStart := time.Now() root, err := l.loadPackage(req.Package) if err != nil { return nil, err } + l.stats.rootLoad = time.Since(rootLoadStart) logDuration(ctx, "loader.custom.lazy.read_files.cumulative", l.stats.read) logDuration(ctx, "loader.custom.lazy.parse_files.cumulative", l.stats.parse) logDuration(ctx, "loader.custom.lazy.typecheck.cumulative", l.stats.typecheck) @@ -279,6 +301,9 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz logDuration(ctx, "loader.custom.lazy.artifact_decode", l.stats.artifactDecode) logDuration(ctx, "loader.custom.lazy.artifact_import_link", l.stats.artifactImportLink) logDuration(ctx, "loader.custom.lazy.artifact_write", l.stats.artifactWrite) + logDuration(ctx, "loader.custom.lazy.artifact_prefetch.wall", l.stats.artifactPrefetch) + logDuration(ctx, "loader.custom.lazy.root_load.wall", l.stats.rootLoad) + logDuration(ctx, "loader.custom.lazy.discovery.wall", l.stats.discovery) logInt(ctx, "loader.custom.lazy.artifact_hits", l.stats.artifactHits) logInt(ctx, "loader.custom.lazy.artifact_misses", l.stats.artifactMisses) logInt(ctx, "loader.custom.lazy.artifact_writes", l.stats.artifactWrites) @@ -289,6 +314,7 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz } func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLoadResult, error) { + discoveryStart := time.Now() meta, err := runGoList(ctx, goListRequest{ WD: req.WD, Env: req.Env, @@ -299,6 +325,7 @@ func 
loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo if err != nil { return nil, err } + discoveryDuration := time.Since(discoveryStart) if len(meta) == 0 { return nil, unsupportedError{reason: "empty go list result"} } @@ -329,8 +356,14 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), loading: make(map[string]bool, len(meta)), isLocalCache: make(map[string]bool, len(meta)), - stats: typedLoadStats{}, - } + localSemanticOK: make(map[string]bool, len(meta)), + artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), + stats: typedLoadStats{discovery: discoveryDuration}, + } + prefetchStart := time.Now() + l.prefetchArtifacts() + l.stats.artifactPrefetch = time.Since(prefetchStart) + rootLoadStart := time.Now() roots := make([]*packages.Package, 0, len(targets)) for _, m := range meta { if m.DepOnly { @@ -342,6 +375,7 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo } roots = append(roots, root) } + l.stats.rootLoad = time.Since(rootLoadStart) sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) logDuration(ctx, "loader.custom.typed.read_files.cumulative", l.stats.read) logDuration(ctx, "loader.custom.typed.parse_files.cumulative", l.stats.parse) @@ -357,6 +391,9 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo logDuration(ctx, "loader.custom.typed.artifact_decode", l.stats.artifactDecode) logDuration(ctx, "loader.custom.typed.artifact_import_link", l.stats.artifactImportLink) logDuration(ctx, "loader.custom.typed.artifact_write", l.stats.artifactWrite) + logDuration(ctx, "loader.custom.typed.artifact_prefetch.wall", l.stats.artifactPrefetch) + logDuration(ctx, "loader.custom.typed.root_load.wall", l.stats.rootLoad) + logDuration(ctx, "loader.custom.typed.discovery.wall", l.stats.discovery) logInt(ctx, 
"loader.custom.typed.artifact_hits", l.stats.artifactHits) logInt(ctx, "loader.custom.typed.artifact_misses", l.stats.artifactMisses) logInt(ctx, "loader.custom.typed.artifact_writes", l.stats.artifactWrites) @@ -527,7 +564,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er } l.packages[path] = pkg } - useArtifact := loaderArtifactEnabled(l.env) && !isTarget && (!isLocal || l.useLocalSemanticArtifact(meta)) + useArtifact := l.shouldUseArtifact(path, meta, isTarget, isLocal) if useArtifact { if typed, ok := l.readArtifact(path, meta, isLocal); ok { linkStart := time.Now() @@ -641,10 +678,15 @@ func (l *customTypedGraphLoader) useLocalSemanticArtifact(meta *packageMeta) boo if meta == nil { return false } + if ok, exists := l.localSemanticOK[meta.ImportPath]; exists { + return ok + } art, err := semanticcache.Read(l.env, meta.ImportPath, meta.Name, metaFiles(meta)) if err != nil || art == nil { + l.localSemanticOK[meta.ImportPath] = false return false } + l.localSemanticOK[meta.ImportPath] = art.Supported return art.Supported } @@ -659,14 +701,26 @@ func (l *customTypedGraphLoader) checkFiles(path string, checker *types.Checker, func (l *customTypedGraphLoader) readArtifact(path string, meta *packageMeta, isLocal bool) (*types.Package, bool) { start := time.Now() - pathStart := time.Now() - artifactPath, err := loaderArtifactPath(l.env, meta, isLocal) - l.stats.artifactPath += time.Since(pathStart) - if err != nil { - debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=path_error err=%v", path, isLocal, err) - l.stats.artifactRead += time.Since(start) - l.stats.artifactMisses++ - return nil, false + entry, prefetched := l.artifactPrefetch[path] + artifactPath := "" + if prefetched { + artifactPath = entry.path + if entry.err != nil { + debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=prefetch_error err=%v", path, isLocal, entry.err) + l.stats.artifactMisses++ + return nil, false + } + } else { + pathStart 
:= time.Now() + var err error + artifactPath, err = loaderArtifactPath(l.env, meta, isLocal) + l.stats.artifactPath += time.Since(pathStart) + if err != nil { + debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=path_error err=%v", path, isLocal, err) + l.stats.artifactRead += time.Since(start) + l.stats.artifactMisses++ + return nil, false + } } if isLocal { preloadStart := time.Now() @@ -690,7 +744,12 @@ func (l *customTypedGraphLoader) readArtifact(path string, meta *packageMeta, is } var tpkg *types.Package decodeStart := time.Now() - tpkg, err = readLoaderArtifact(artifactPath, l.fset, l.typesPkgs, path) + var err error + if prefetched { + tpkg, err = readLoaderArtifactData(entry.data, l.fset, l.typesPkgs, path) + } else { + tpkg, err = readLoaderArtifact(artifactPath, l.fset, l.typesPkgs, path) + } l.stats.artifactDecode += time.Since(decodeStart) if err != nil { debugf(l.ctx, "loader.artifact.read_miss pkg=%s local=%t reason=decode_error err=%v", path, isLocal, err) @@ -698,12 +757,92 @@ func (l *customTypedGraphLoader) readArtifact(path string, meta *packageMeta, is l.stats.artifactMisses++ return nil, false } - l.stats.artifactRead += time.Since(start) + if !prefetched { + l.stats.artifactRead += time.Since(start) + } l.stats.artifactHits++ l.typesPkgs[path] = tpkg return tpkg, true } +func (l *customTypedGraphLoader) shouldUseArtifact(path string, meta *packageMeta, isTarget, isLocal bool) bool { + if !loaderArtifactEnabled(l.env) || isTarget { + return false + } + if !isLocal { + return true + } + return l.useLocalSemanticArtifact(meta) +} + +func (l *customTypedGraphLoader) prefetchArtifacts() { + if !loaderArtifactEnabled(l.env) { + return + } + candidates := make([]string, 0, len(l.meta)) + for path, meta := range l.meta { + _, isTarget := l.targets[path] + isLocal := l.isLocalPackage(path, meta) + if l.shouldUseArtifact(path, meta, isTarget, isLocal) { + candidates = append(candidates, path) + } + } + sort.Strings(candidates) + if 
len(candidates) == 0 { + return + } + type result struct { + pkg string + entry artifactPrefetchEntry + dur time.Duration + } + jobs := make(chan string, len(candidates)) + results := make(chan result, len(candidates)) + workers := 8 + if len(candidates) < workers { + workers = len(candidates) + } + var wg sync.WaitGroup + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for path := range jobs { + start := time.Now() + meta := l.meta[path] + isLocal := l.isLocalPackage(path, meta) + artifactPath, err := loaderArtifactPath(l.env, meta, isLocal) + entry := artifactPrefetchEntry{path: artifactPath} + if err == nil { + data, readErr := os.ReadFile(artifactPath) + if readErr != nil { + entry.err = readErr + } else { + entry.data = data + entry.ok = true + } + } else { + entry.err = err + } + results <- result{pkg: path, entry: entry, dur: time.Since(start)} + } + }() + } + for _, path := range candidates { + jobs <- path + } + close(jobs) + wg.Wait() + close(results) + for res := range results { + l.artifactPrefetch[res.pkg] = res.entry + l.stats.artifactRead += res.dur + pathStart := time.Now() + _ = res.entry.path + l.stats.artifactPath += time.Since(pathStart) + } +} + func (l *customTypedGraphLoader) writeArtifact(meta *packageMeta, pkg *types.Package, isLocal bool) error { start := time.Now() artifactPath, err := loaderArtifactPath(l.env, meta, isLocal) diff --git a/internal/loader/discovery.go b/internal/loader/discovery.go index 9adbf4e..422a3a4 100644 --- a/internal/loader/discovery.go +++ b/internal/loader/discovery.go @@ -34,6 +34,9 @@ type goListRequest struct { } func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, error) { + if cached, ok := readDiscoveryCache(req); ok { + return cached, nil + } args := []string{"list", "-json", "-e", "-compiled", "-export"} if req.NeedDeps { args = append(args, "-deps") @@ -91,5 +94,6 @@ func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, 
copyMeta := meta out[meta.ImportPath] = ©Meta } + writeDiscoveryCache(req, out) return out, nil } diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go new file mode 100644 index 0000000..4ec9a12 --- /dev/null +++ b/internal/loader/discovery_cache.go @@ -0,0 +1,331 @@ +package loader + +import ( + "bytes" + "crypto/sha256" + "encoding/gob" + "encoding/hex" + "go/parser" + "go/token" + "os" + "path/filepath" + "runtime" + "sort" +) + +type discoveryCacheEntry struct { + Version int + WD string + Tags string + Patterns []string + NeedDeps bool + Workspace string + Meta map[string]*packageMeta + Global []discoveryFileMeta + LocalPkgs []discoveryLocalPackage +} + +type discoveryLocalPackage struct { + ImportPath string + Dir string + DirMeta discoveryDirMeta + Files []discoveryFileFingerprint +} + +type discoveryFileMeta struct { + Path string + Size int64 + ModTime int64 + IsDir bool +} + +type discoveryDirMeta struct { + Path string + Entries []string +} + +type discoveryFileFingerprint struct { + Path string + Hash string +} + +func readDiscoveryCache(req goListRequest) (map[string]*packageMeta, bool) { + entry, err := loadDiscoveryCacheEntry(req) + if err != nil || entry == nil { + return nil, false + } + if !validateDiscoveryCacheEntry(entry) { + return nil, false + } + return clonePackageMetaMap(entry.Meta), true +} + +func writeDiscoveryCache(req goListRequest, meta map[string]*packageMeta) { + entry, err := buildDiscoveryCacheEntry(req, meta) + if err != nil { + return + } + _ = saveDiscoveryCacheEntry(req, entry) +} + +func buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) (*discoveryCacheEntry, error) { + workspace := detectModuleRoot(req.WD) + entry := &discoveryCacheEntry{ + Version: 2, + WD: canonicalLoaderPath(req.WD), + Tags: req.Tags, + Patterns: append([]string(nil), req.Patterns...), + NeedDeps: req.NeedDeps, + Workspace: workspace, + Meta: clonePackageMetaMap(meta), + } + global := []string{ + 
filepath.Join(workspace, "go.mod"), + filepath.Join(workspace, "go.sum"), + filepath.Join(workspace, "go.work"), + filepath.Join(workspace, "go.work.sum"), + } + for _, name := range global { + if fm, ok := statDiscoveryFile(name); ok { + entry.Global = append(entry.Global, fm) + } + } + locals := make([]discoveryLocalPackage, 0) + for _, pkg := range meta { + if pkg == nil || !isWorkspacePackage(workspace, pkg.Dir) { + continue + } + lp := discoveryLocalPackage{ + ImportPath: pkg.ImportPath, + Dir: pkg.Dir, + } + if fm, ok := statDiscoveryDir(pkg.Dir); ok { + lp.DirMeta = fm + } + for _, name := range metaFiles(pkg) { + if fm, ok := fingerprintDiscoveryFile(name); ok { + lp.Files = append(lp.Files, fm) + } + } + sort.Slice(lp.Files, func(i, j int) bool { return lp.Files[i].Path < lp.Files[j].Path }) + locals = append(locals, lp) + } + sort.Slice(locals, func(i, j int) bool { return locals[i].ImportPath < locals[j].ImportPath }) + entry.LocalPkgs = locals + return entry, nil +} + +func validateDiscoveryCacheEntry(entry *discoveryCacheEntry) bool { + if entry == nil || entry.Version != 2 { + return false + } + for _, fm := range entry.Global { + if !matchesDiscoveryFile(fm) { + return false + } + } + for _, lp := range entry.LocalPkgs { + if !matchesDiscoveryDir(lp.DirMeta) { + return false + } + for _, fm := range lp.Files { + if !matchesDiscoveryFingerprint(fm) { + return false + } + } + } + return true +} + +func discoveryCachePath(req goListRequest) (string, error) { + base, err := os.UserCacheDir() + if err != nil { + return "", err + } + sumReq := struct { + Version int + WD string + Tags string + Patterns []string + NeedDeps bool + Go string + }{ + Version: 2, + WD: canonicalLoaderPath(req.WD), + Tags: req.Tags, + Patterns: append([]string(nil), req.Patterns...), + NeedDeps: req.NeedDeps, + Go: runtime.Version(), + } + key, err := hashGob(sumReq) + if err != nil { + return "", err + } + return filepath.Join(base, "wire", "discovery-cache", key+".gob"), nil +} 
+ +func loadDiscoveryCacheEntry(req goListRequest) (*discoveryCacheEntry, error) { + path, err := discoveryCachePath(req) + if err != nil { + return nil, err + } + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + var entry discoveryCacheEntry + if err := gob.NewDecoder(f).Decode(&entry); err != nil { + return nil, err + } + return &entry, nil +} + +func saveDiscoveryCacheEntry(req goListRequest, entry *discoveryCacheEntry) error { + path, err := discoveryCachePath(req) + if err != nil { + return err + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + return gob.NewEncoder(f).Encode(entry) +} + +func statDiscoveryFile(path string) (discoveryFileMeta, bool) { + info, err := os.Stat(path) + if err != nil { + return discoveryFileMeta{}, false + } + return discoveryFileMeta{ + Path: canonicalLoaderPath(path), + Size: info.Size(), + ModTime: info.ModTime().UnixNano(), + IsDir: info.IsDir(), + }, true +} + +func matchesDiscoveryFile(fm discoveryFileMeta) bool { + cur, ok := statDiscoveryFile(fm.Path) + if !ok { + return false + } + return cur.Size == fm.Size && cur.ModTime == fm.ModTime && cur.IsDir == fm.IsDir +} + +func statDiscoveryDir(path string) (discoveryDirMeta, bool) { + entries, err := os.ReadDir(path) + if err != nil { + return discoveryDirMeta{}, false + } + names := make([]string, 0, len(entries)) + for _, entry := range entries { + names = append(names, entry.Name()) + } + sort.Strings(names) + return discoveryDirMeta{ + Path: canonicalLoaderPath(path), + Entries: names, + }, true +} + +func matchesDiscoveryDir(dm discoveryDirMeta) bool { + cur, ok := statDiscoveryDir(dm.Path) + if !ok { + return false + } + if len(cur.Entries) != len(dm.Entries) { + return false + } + for i := range cur.Entries { + if cur.Entries[i] != dm.Entries[i] { + return false + } + } + return true +} + +func 
fingerprintDiscoveryFile(path string) (discoveryFileFingerprint, bool) { + src, err := os.ReadFile(path) + if err != nil { + return discoveryFileFingerprint{}, false + } + sum := sha256.New() + sum.Write([]byte(filepath.Base(path))) + sum.Write([]byte{0}) + file, err := parser.ParseFile(token.NewFileSet(), path, src, parser.ImportsOnly|parser.ParseComments) + if err != nil { + sum.Write(src) + return discoveryFileFingerprint{ + Path: canonicalLoaderPath(path), + Hash: hex.EncodeToString(sum.Sum(nil)), + }, true + } + if offset := int(file.Package) - 1; offset > 0 && offset <= len(src) { + sum.Write(src[:offset]) + } + sum.Write([]byte(file.Name.Name)) + sum.Write([]byte{0}) + for _, imp := range file.Imports { + if imp.Name != nil { + sum.Write([]byte(imp.Name.Name)) + } + sum.Write([]byte{0}) + sum.Write([]byte(imp.Path.Value)) + sum.Write([]byte{0}) + } + return discoveryFileFingerprint{ + Path: canonicalLoaderPath(path), + Hash: hex.EncodeToString(sum.Sum(nil)), + }, true +} + +func matchesDiscoveryFingerprint(fp discoveryFileFingerprint) bool { + cur, ok := fingerprintDiscoveryFile(fp.Path) + if !ok { + return false + } + return cur.Hash == fp.Hash +} + +func hashGob(v interface{}) (string, error) { + var buf bytes.Buffer + if err := gob.NewEncoder(&buf).Encode(v); err != nil { + return "", err + } + sum := sha256.Sum256(buf.Bytes()) + return hex.EncodeToString(sum[:]), nil +} + +func clonePackageMetaMap(in map[string]*packageMeta) map[string]*packageMeta { + if len(in) == 0 { + return nil + } + out := make(map[string]*packageMeta, len(in)) + for k, v := range in { + if v == nil { + continue + } + cp := *v + cp.GoFiles = append([]string(nil), v.GoFiles...) + cp.CompiledGoFiles = append([]string(nil), v.CompiledGoFiles...) + cp.Imports = append([]string(nil), v.Imports...) 
+ if v.ImportMap != nil { + cp.ImportMap = make(map[string]string, len(v.ImportMap)) + for mk, mv := range v.ImportMap { + cp.ImportMap[mk] = mv + } + } + if v.Error != nil { + errCopy := *v.Error + cp.Error = &errCopy + } + out[k] = &cp + } + return out +} diff --git a/internal/loader/discovery_cache_test.go b/internal/loader/discovery_cache_test.go new file mode 100644 index 0000000..953824f --- /dev/null +++ b/internal/loader/discovery_cache_test.go @@ -0,0 +1,126 @@ +package loader + +import ( + "os" + "path/filepath" + "testing" +) + +func TestDiscoveryFingerprintIgnoresBodyOnlyEdits(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "pkg.go") + before := `package example + +import "fmt" + +func Provide() string { + return fmt.Sprint("before") +} +` + if err := os.WriteFile(path, []byte(before), 0o644); err != nil { + t.Fatal(err) + } + fpBefore, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed", path) + } + after := `package example + +import "fmt" + +func Provide() string { + return fmt.Sprint("after") +} +` + if err := os.WriteFile(path, []byte(after), 0o644); err != nil { + t.Fatal(err) + } + fpAfter, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed after body edit", path) + } + if fpBefore.Hash != fpAfter.Hash { + t.Fatalf("body-only edit changed fingerprint: %s != %s", fpBefore.Hash, fpAfter.Hash) + } +} + +func TestDiscoveryFingerprintDetectsImportChange(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "pkg.go") + before := `package example + +import "fmt" +` + if err := os.WriteFile(path, []byte(before), 0o644); err != nil { + t.Fatal(err) + } + fpBefore, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed", path) + } + after := `package example + +import "strings" +` + if err := os.WriteFile(path, []byte(after), 0o644); err != nil { + t.Fatal(err) + } + fpAfter, ok := 
fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed after import edit", path) + } + if fpBefore.Hash == fpAfter.Hash { + t.Fatalf("import edit did not change fingerprint") + } +} + +func TestDiscoveryFingerprintDetectsHeaderBuildTagChange(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "pkg.go") + before := `//go:build linux + +package example + +import "fmt" +` + if err := os.WriteFile(path, []byte(before), 0o644); err != nil { + t.Fatal(err) + } + fpBefore, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed", path) + } + after := `//go:build darwin + +package example + +import "fmt" +` + if err := os.WriteFile(path, []byte(after), 0o644); err != nil { + t.Fatal(err) + } + fpAfter, ok := fingerprintDiscoveryFile(path) + if !ok { + t.Fatalf("fingerprintDiscoveryFile(%q) failed after header edit", path) + } + if fpBefore.Hash == fpAfter.Hash { + t.Fatalf("build tag edit did not change fingerprint") + } +} + +func TestDiscoveryDirDetectsFileSetChange(t *testing.T) { + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "a.go"), []byte("package example\n"), 0o644); err != nil { + t.Fatal(err) + } + before, ok := statDiscoveryDir(dir) + if !ok { + t.Fatalf("statDiscoveryDir(%q) failed", dir) + } + if err := os.WriteFile(filepath.Join(dir, "b.go"), []byte("package example\n"), 0o644); err != nil { + t.Fatal(err) + } + if matchesDiscoveryDir(before) { + t.Fatalf("directory metadata did not detect added file") + } +} diff --git a/internal/wire/profile_bench_test.go b/internal/wire/profile_bench_test.go new file mode 100644 index 0000000..31dc7b7 --- /dev/null +++ b/internal/wire/profile_bench_test.go @@ -0,0 +1,32 @@ +package wire + +import ( + "context" + "os" + "testing" +) + +func BenchmarkGenerateRealAppWarmArtifacts(b *testing.B) { + root := os.Getenv("WIRE_REAL_APP_ROOT") + if root == "" { + b.Skip("WIRE_REAL_APP_ROOT not set") + } + artifactDir := 
b.TempDir() + env := append(os.Environ(), + "WIRE_LOADER_ARTIFACTS=1", + "WIRE_LOADER_ARTIFACT_DIR="+artifactDir, + ) + ctx := context.Background() + + // Warm the artifact cache once before measurement. + if _, errs := Generate(ctx, root, env, []string{"."}, &GenerateOptions{}); len(errs) > 0 { + b.Fatalf("warm Generate errors: %v", errs) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, errs := Generate(ctx, root, env, []string{"."}, &GenerateOptions{}); len(errs) > 0 { + b.Fatalf("Generate errors: %v", errs) + } + } +} From 26f591bef82bc19092839dfae14971dc89646da8 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 06:54:13 -0500 Subject: [PATCH 13/82] feat(loader): cache unchanged root output --- internal/loader/custom.go | 1 + internal/wire/output_cache.go | 273 ++++++++++++++++++++++++++++++++++ internal/wire/wire.go | 5 + 3 files changed, 279 insertions(+) create mode 100644 internal/wire/output_cache.go diff --git a/internal/loader/custom.go b/internal/loader/custom.go index c938acc..dd2b9e0 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -184,6 +184,7 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes PkgPath: m.ImportPath, GoFiles: append([]string(nil), metaFiles(m)...), CompiledGoFiles: append([]string(nil), metaFiles(m)...), + ExportFile: m.Export, Imports: make(map[string]*packages.Package), } if m.Error != nil && strings.TrimSpace(m.Error.Err) != "" { diff --git a/internal/wire/output_cache.go b/internal/wire/output_cache.go new file mode 100644 index 0000000..7d384fb --- /dev/null +++ b/internal/wire/output_cache.go @@ -0,0 +1,273 @@ +package wire + +import ( + "context" + "crypto/sha256" + "encoding/gob" + "encoding/hex" + "os" + "path/filepath" + "runtime" + "sort" + "strconv" + "strings" + + "golang.org/x/tools/go/packages" + + "github.com/goforj/wire/internal/loader" +) + +const outputCacheDirEnv = "WIRE_OUTPUT_CACHE_DIR" + +type outputCacheEntry struct { + Version 
int + Content []byte +} + +type outputCacheCandidate struct { + path string + outputPath string +} + +func prepareGenerateOutputCache(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) (map[string]outputCacheCandidate, []GenerateResult, bool) { + if !outputCacheEnabled(ctx, wd, env) { + debugf(ctx, "generate.output_cache=disabled") + return nil, nil, false + } + rootResult, err := loader.New().LoadRootGraph(withLoaderTiming(ctx), loader.RootLoadRequest{ + WD: wd, + Env: env, + Tags: opts.Tags, + Patterns: append([]string(nil), patterns...), + NeedDeps: true, + Mode: effectiveLoaderMode(ctx, wd, env), + }) + if err != nil || rootResult == nil || len(rootResult.Packages) == 0 { + if err != nil { + debugf(ctx, "generate.output_cache=load_root_error") + } else { + debugf(ctx, "generate.output_cache=no_roots") + } + return nil, nil, false + } + candidates := make(map[string]outputCacheCandidate, len(rootResult.Packages)) + results := make([]GenerateResult, 0, len(rootResult.Packages)) + for _, pkg := range rootResult.Packages { + outDir, err := detectOutputDir(pkg.GoFiles) + if err != nil { + debugf(ctx, "generate.output_cache=bad_output_dir") + return candidates, nil, false + } + key, err := outputCacheKey(wd, opts, pkg) + if err != nil { + debugf(ctx, "generate.output_cache=key_error") + return candidates, nil, false + } + path, err := outputCachePath(env, key) + if err != nil { + debugf(ctx, "generate.output_cache=path_error") + return candidates, nil, false + } + candidates[pkg.PkgPath] = outputCacheCandidate{ + path: path, + outputPath: filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go"), + } + entry, ok := readOutputCache(path) + if !ok { + debugf(ctx, "generate.output_cache=miss") + return candidates, nil, false + } + results = append(results, GenerateResult{ + PkgPath: pkg.PkgPath, + OutputPath: filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go"), + Content: entry.Content, + }) + } + debugf(ctx, 
"generate.output_cache=hit") + return candidates, results, len(results) == len(rootResult.Packages) +} + +func writeGenerateOutputCache(candidates map[string]outputCacheCandidate, generated []GenerateResult) { + for _, gen := range generated { + candidate, ok := candidates[gen.PkgPath] + if !ok || candidate.path == "" || len(gen.Errs) > 0 || len(gen.Content) == 0 { + continue + } + _ = writeOutputCache(candidate.path, &outputCacheEntry{ + Version: 1, + Content: append([]byte(nil), gen.Content...), + }) + } +} + +func outputCacheEnabled(ctx context.Context, wd string, env []string) bool { + if effectiveLoaderMode(ctx, wd, env) == loader.ModeFallback { + return false + } + return envValue(env, "WIRE_LOADER_ARTIFACTS") == "1" +} + +func outputCachePath(env []string, key string) (string, error) { + dir, err := outputCacheDir(env) + if err != nil { + return "", err + } + return filepath.Join(dir, key+".gob"), nil +} + +func outputCacheDir(env []string) (string, error) { + if dir := envValue(env, outputCacheDirEnv); dir != "" { + return dir, nil + } + base, err := os.UserCacheDir() + if err != nil { + return "", err + } + return filepath.Join(base, "wire", "output-cache"), nil +} + +func readOutputCache(path string) (*outputCacheEntry, bool) { + f, err := os.Open(path) + if err != nil { + return nil, false + } + defer f.Close() + var entry outputCacheEntry + if err := gob.NewDecoder(f).Decode(&entry); err != nil { + return nil, false + } + if entry.Version != 1 { + return nil, false + } + return &entry, true +} + +func writeOutputCache(path string, entry *outputCacheEntry) error { + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + return gob.NewEncoder(f).Encode(entry) +} + +func outputCacheKey(wd string, opts *GenerateOptions, root *packages.Package) (string, error) { + sum := sha256.New() + sum.Write([]byte("wire-output-cache-v1\n")) + 
sum.Write([]byte(runtime.Version())) + sum.Write([]byte{'\n'}) + sum.Write([]byte(canonicalWirePath(wd))) + sum.Write([]byte{'\n'}) + sum.Write(opts.Header) + sum.Write([]byte{'\n'}) + sum.Write([]byte(opts.Tags)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(root.PkgPath)) + sum.Write([]byte{'\n'}) + workspace := detectWireModuleRoot(wd) + pkgs := reachablePackages(root) + for _, pkg := range pkgs { + sum.Write([]byte(pkg.PkgPath)) + sum.Write([]byte{'\n'}) + if isLocalWirePackage(workspace, pkg) { + files := append([]string(nil), pkg.GoFiles...) + sort.Strings(files) + for _, name := range files { + info, err := os.Stat(name) + if err != nil { + return "", err + } + sum.Write([]byte(name)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.Size(), 10))) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.ModTime().UnixNano(), 10))) + sum.Write([]byte{'\n'}) + if pkg.PkgPath == root.PkgPath { + src, err := os.ReadFile(name) + if err != nil { + return "", err + } + sum.Write(src) + sum.Write([]byte{'\n'}) + } + } + continue + } + sum.Write([]byte(pkg.ExportFile)) + sum.Write([]byte{'\n'}) + } + return hex.EncodeToString(sum.Sum(nil)), nil +} + +func reachablePackages(root *packages.Package) []*packages.Package { + seen := map[string]bool{} + var out []*packages.Package + var walk func(*packages.Package) + walk = func(pkg *packages.Package) { + if pkg == nil || seen[pkg.PkgPath] { + return + } + seen[pkg.PkgPath] = true + out = append(out, pkg) + paths := make([]string, 0, len(pkg.Imports)) + for path := range pkg.Imports { + paths = append(paths, path) + } + sort.Strings(paths) + for _, path := range paths { + walk(pkg.Imports[path]) + } + } + walk(root) + sort.Slice(out, func(i, j int) bool { return out[i].PkgPath < out[j].PkgPath }) + return out +} + +func isLocalWirePackage(workspace string, pkg *packages.Package) bool { + if pkg == nil || len(pkg.GoFiles) == 0 { + return false + } + dir := filepath.Dir(pkg.GoFiles[0]) + dir = 
canonicalWirePath(dir) + workspace = canonicalWirePath(workspace) + if dir == workspace { + return true + } + return len(dir) > len(workspace) && dir[:len(workspace)] == workspace && dir[len(workspace)] == filepath.Separator +} + +func detectWireModuleRoot(start string) string { + start = canonicalWirePath(start) + for dir := start; dir != "" && dir != "." && dir != string(filepath.Separator); dir = filepath.Dir(dir) { + if info, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !info.IsDir() { + return dir + } + next := filepath.Dir(dir) + if next == dir { + break + } + } + return start +} + +func canonicalWirePath(path string) string { + path = filepath.Clean(path) + if resolved, err := filepath.EvalSymlinks(path); err == nil && resolved != "" { + return filepath.Clean(resolved) + } + return path +} + +func envValue(env []string, key string) string { + for i := len(env) - 1; i >= 0; i-- { + name, value, ok := strings.Cut(env[i], "=") + if ok && name == key { + return value + } + } + return "" +} diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 99062ac..3d787f3 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -102,6 +102,10 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o if opts == nil { opts = &GenerateOptions{} } + cacheCandidates, cached, ok := prepareGenerateOutputCache(ctx, wd, env, patterns, opts) + if ok { + return cached, nil + } loadStart := time.Now() pkgs, errs := load(ctx, wd, env, opts.Tags, patterns) logTiming(ctx, "generate.load", loadStart) @@ -149,6 +153,7 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o generated[i].Content = goSrc logTiming(ctx, "generate.package."+pkg.PkgPath+".total", pkgStart) } + writeGenerateOutputCache(cacheCandidates, generated) return generated, nil } From a357a384f501cc87d0687b497e07b67e947744ee Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 08:18:47 -0500 Subject: [PATCH 14/82] chore: 
bench tweaks --- internal/loader/artifact_cache.go | 2 +- internal/loader/discovery.go | 35 +++- internal/wire/import_bench_test.go | 311 +++++++++++++++++++++++++++++ internal/wire/output_cache.go | 2 +- scripts/import-benchmarks.sh | 33 +++ 5 files changed, 379 insertions(+), 4 deletions(-) create mode 100644 internal/wire/import_bench_test.go create mode 100755 scripts/import-benchmarks.sh diff --git a/internal/loader/artifact_cache.go b/internal/loader/artifact_cache.go index d495143..42293cb 100644 --- a/internal/loader/artifact_cache.go +++ b/internal/loader/artifact_cache.go @@ -34,7 +34,7 @@ const ( ) func loaderArtifactEnabled(env []string) bool { - return envValue(env, loaderArtifactEnv) == "1" + return envValue(env, loaderArtifactEnv) != "0" } func loaderArtifactDir(env []string) (string, error) { diff --git a/internal/loader/discovery.go b/internal/loader/discovery.go index 422a3a4..22b34d3 100644 --- a/internal/loader/discovery.go +++ b/internal/loader/discovery.go @@ -23,6 +23,7 @@ import ( "os" "os/exec" "path/filepath" + "time" ) type goListRequest struct { @@ -34,10 +35,18 @@ type goListRequest struct { } func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, error) { + cacheReadStart := time.Now() if cached, ok := readDiscoveryCache(req); ok { + logDuration(ctx, "loader.discovery.cache_read.wall", time.Since(cacheReadStart)) + logDuration(ctx, "loader.discovery.golist.wall", 0) + logDuration(ctx, "loader.discovery.decode.wall", 0) + logDuration(ctx, "loader.discovery.canonicalize.wall", 0) + logDuration(ctx, "loader.discovery.cache_build.wall", 0) + logDuration(ctx, "loader.discovery.cache_write.wall", 0) return cached, nil } - args := []string{"list", "-json", "-e", "-compiled", "-export"} + logDuration(ctx, "loader.discovery.cache_read.wall", time.Since(cacheReadStart)) + args := []string{"list", "-json", "-e", "-compiled"} if req.NeedDeps { args = append(args, "-deps") } @@ -60,22 +69,30 @@ func runGoList(ctx 
context.Context, req goListRequest) (map[string]*packageMeta, var stderr bytes.Buffer cmd.Stdout = &stdout cmd.Stderr = &stderr + goListStart := time.Now() if err := cmd.Run(); err != nil { return nil, fmt.Errorf("go list: %w: %s", err, stderr.String()) } + goListDuration := time.Since(goListStart) dec := json.NewDecoder(&stdout) out := make(map[string]*packageMeta) + var decodeDuration time.Duration + var canonicalizeDuration time.Duration for { var meta packageMeta + decodeStart := time.Now() if err := dec.Decode(&meta); err != nil { + decodeDuration += time.Since(decodeStart) if err == io.EOF { break } return nil, err } + decodeDuration += time.Since(decodeStart) if meta.ImportPath == "" { continue } + canonicalizeStart := time.Now() meta.Dir = canonicalLoaderPath(meta.Dir) for i, name := range meta.GoFiles { if !filepath.IsAbs(name) { @@ -91,9 +108,23 @@ func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, meta.Export = filepath.Join(meta.Dir, meta.Export) } meta.Imports = normalizeImports(meta.Imports, meta.ImportMap) + canonicalizeDuration += time.Since(canonicalizeStart) copyMeta := meta out[meta.ImportPath] = ©Meta } - writeDiscoveryCache(req, out) + cacheBuildStart := time.Now() + entry, err := buildDiscoveryCacheEntry(req, out) + cacheBuildDuration := time.Since(cacheBuildStart) + if err == nil && entry != nil { + cacheWriteStart := time.Now() + _ = saveDiscoveryCacheEntry(req, entry) + logDuration(ctx, "loader.discovery.cache_write.wall", time.Since(cacheWriteStart)) + } else { + logDuration(ctx, "loader.discovery.cache_write.wall", 0) + } + logDuration(ctx, "loader.discovery.golist.wall", goListDuration) + logDuration(ctx, "loader.discovery.decode.wall", decodeDuration) + logDuration(ctx, "loader.discovery.canonicalize.wall", canonicalizeDuration) + logDuration(ctx, "loader.discovery.cache_build.wall", cacheBuildDuration) return out, nil } diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go 
new file mode 100644 index 0000000..e134bc2 --- /dev/null +++ b/internal/wire/import_bench_test.go @@ -0,0 +1,311 @@ +package wire + +import ( + "archive/tar" + "bytes" + "context" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" +) + +const ( + importBenchEnv = "WIRE_IMPORT_BENCH_TABLE" + stockWireCommit = "9c25c9016f6825302537c4efdd5e897976f9c826" + stockWireModulePath = "github.com/google/wire" + currentWireModulePath = "github.com/goforj/wire" +) + +type importBenchRow struct { + imports int + stockCold time.Duration + currentCold time.Duration + currentWarm time.Duration +} + +const importBenchTrials = 3 + +func TestPrintImportScaleBenchmarkTable(t *testing.T) { + if os.Getenv(importBenchEnv) != "1" { + t.Skipf("%s not set", importBenchEnv) + } + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + currentBin := buildWireBinary(t, repoRoot, "current-wire") + stockDir := extractStockWire(t, repoRoot, stockWireCommit) + stockBin := buildWireBinary(t, stockDir, "stock-wire") + + sizes := []int{10, 100, 1000} + rows := make([]importBenchRow, 0, len(sizes)) + for _, n := range sizes { + stockFixture := createImportBenchFixture(t, n, stockWireModulePath, stockDir) + currentFixture := createImportBenchFixture(t, n, currentWireModulePath, repoRoot) + rows = append(rows, importBenchRow{ + imports: n, + stockCold: medianDuration(runColdTrials(t, stockBin, stockFixture, importBenchTrials)), + currentCold: medianDuration(runColdTrials(t, currentBin, currentFixture, importBenchTrials)), + currentWarm: medianDuration(runWarmTrials(t, currentBin, currentFixture, importBenchTrials)), + }) + } + printImportBenchTable(t, rows) +} + +func buildWireBinary(t *testing.T, dir, name string) string { + t.Helper() + out := filepath.Join(t.TempDir(), name) + cmd := exec.Command("go", "build", "-o", out, "./cmd/wire") + cmd.Dir = dir + cmd.Env = benchEnv(t.TempDir(), filepath.Join(t.TempDir(), "gocache")) + output, err := 
cmd.CombinedOutput() + if err != nil { + t.Fatalf("build wire binary in %s: %v\n%s", dir, err, output) + } + return out +} + +func extractStockWire(t *testing.T, repoRoot, commit string) string { + t.Helper() + tmp := t.TempDir() + cmd := exec.Command("git", "archive", "--format=tar", commit) + cmd.Dir = repoRoot + stdout, err := cmd.StdoutPipe() + if err != nil { + t.Fatal(err) + } + if err := cmd.Start(); err != nil { + t.Fatalf("git archive start: %v", err) + } + tr := tar.NewReader(stdout) + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("read stock tar: %v", err) + } + target := filepath.Join(tmp, hdr.Name) + switch hdr.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(target, os.FileMode(hdr.Mode)); err != nil { + t.Fatalf("mkdir %s: %v", target, err) + } + case tar.TypeReg: + if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { + t.Fatalf("mkdir parent %s: %v", target, err) + } + f, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(hdr.Mode)) + if err != nil { + t.Fatalf("create %s: %v", target, err) + } + if _, err := io.Copy(f, tr); err != nil { + _ = f.Close() + t.Fatalf("write %s: %v", target, err) + } + if err := f.Close(); err != nil { + t.Fatalf("close %s: %v", target, err) + } + } + } + if err := cmd.Wait(); err != nil { + t.Fatalf("git archive wait: %v", err) + } + return tmp +} + +func createImportBenchFixture(t *testing.T, imports int, wireModulePath, wireReplaceDir string) string { + t.Helper() + root := t.TempDir() + if err := os.WriteFile(filepath.Join(root, "go.mod"), []byte(importBenchGoMod(wireModulePath, wireReplaceDir)), 0o644); err != nil { + t.Fatal(err) + } + for i := 0; i < imports; i++ { + dir := filepath.Join(root, fmt.Sprintf("dep%04d", i)) + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatal(err) + } + src := fmt.Sprintf("package dep%04d\n\ntype T struct{}\n\nfunc Provide() *T { return &T{} }\n", i) + if err := 
os.WriteFile(filepath.Join(dir, "dep.go"), []byte(src), 0o644); err != nil { + t.Fatal(err) + } + } + if err := os.MkdirAll(filepath.Join(root, "app"), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(root, "app", "wire.go"), []byte(importBenchWireFile(imports, wireModulePath)), 0o644); err != nil { + t.Fatal(err) + } + return filepath.Join(root, "app") +} + +func runWireBenchCommand(t *testing.T, bin, pkgDir, home, goCache string) time.Duration { + t.Helper() + cmd := exec.Command(bin, "gen") + cmd.Dir = pkgDir + cmd.Env = append(benchEnv(home, goCache), "WIRE_LOADER_ARTIFACTS=1") + var stderr bytes.Buffer + cmd.Stdout = io.Discard + cmd.Stderr = &stderr + start := time.Now() + if err := cmd.Run(); err != nil { + t.Fatalf("run %s in %s: %v\n%s", bin, pkgDir, err, stderr.String()) + } + return time.Since(start) +} + +func runColdTrials(t *testing.T, bin, pkgDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + for i := 0; i < trials; i++ { + home := t.TempDir() + goCache := filepath.Join(t.TempDir(), "gocache") + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, home, goCache)) + } + return durations +} + +func runWarmTrials(t *testing.T, bin, pkgDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + for i := 0; i < trials; i++ { + home := t.TempDir() + goCache := filepath.Join(t.TempDir(), "gocache") + _ = runWireBenchCommand(t, bin, pkgDir, home, goCache) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, home, goCache)) + } + return durations +} + +func medianDuration(durations []time.Duration) time.Duration { + if len(durations) == 0 { + return 0 + } + sorted := append([]time.Duration(nil), durations...) 
+ for i := 1; i < len(sorted); i++ { + for j := i; j > 0 && sorted[j] < sorted[j-1]; j-- { + sorted[j], sorted[j-1] = sorted[j-1], sorted[j] + } + } + return sorted[len(sorted)/2] +} + +func benchEnv(home, goCache string) []string { + env := append([]string(nil), os.Environ()...) + env = append(env, + "HOME="+home, + "GOCACHE="+goCache, + "GOMODCACHE=/tmp/gomodcache", + ) + return env +} + +func importBenchGoMod(wireModulePath, wireReplaceDir string) string { + return fmt.Sprintf(`module example.com/importbench + +go 1.26 + +require %s v0.0.0 + +replace %s => %s +`, wireModulePath, wireModulePath, wireReplaceDir) +} + +func importBenchWireFile(imports int, wireModulePath string) string { + var b strings.Builder + b.WriteString("//go:build wireinject\n\n") + b.WriteString("package app\n\n") + b.WriteString("import (\n") + b.WriteString(fmt.Sprintf("\twire %q\n", wireModulePath)) + for i := 0; i < imports; i++ { + b.WriteString(fmt.Sprintf("\t%[1]q\n", fmt.Sprintf("example.com/importbench/dep%04d", i))) + } + b.WriteString(")\n\n") + b.WriteString("type App struct{}\n\n") + b.WriteString("func provideApp(") + for i := 0; i < imports; i++ { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(fmt.Sprintf("d%d *dep%04d.T", i, i)) + } + b.WriteString(") *App {\n\treturn &App{}\n}\n\n") + b.WriteString("func Initialize() *App {\n\twire.Build(wire.NewSet(\n") + for i := 0; i < imports; i++ { + b.WriteString(fmt.Sprintf("\t\tdep%04d.Provide,\n", i)) + } + b.WriteString("\t\tprovideApp,\n\t))\n\treturn nil\n}\n") + return b.String() +} + +func printImportBenchTable(t *testing.T, rows []importBenchRow) { + t.Helper() + fmt.Println("+-----------+-----------+--------------+-------------------+--------------+-------------------+") + fmt.Println("| repo size | stock | current cold | current unchanged | cold speedup | unchanged speedup |") + fmt.Println("+-----------+-----------+--------------+-------------------+--------------+-------------------+") + for _, row := range rows 
{ + fmt.Printf("| %-9d | %-9s | %-12s | %-17s | %-12s | %-17s |\n", + row.imports, + formatMs(row.stockCold), + formatMs(row.currentCold), + formatMs(row.currentWarm), + formatSpeedup(row.stockCold, row.currentCold), + formatSpeedup(row.stockCold, row.currentWarm), + ) + } + fmt.Println("+-----------+-----------+--------------+-------------------+--------------+-------------------+") +} + +func formatMs(d time.Duration) string { + return fmt.Sprintf("%.1fms", float64(d)/float64(time.Millisecond)) +} + +func formatSpeedup(oldDur, newDur time.Duration) string { + if newDur == 0 { + return "inf" + } + return fmt.Sprintf("%.2fx", float64(oldDur)/float64(newDur)) +} + +func TestImportBenchFixtureGenerates(t *testing.T) { + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + bin := buildWireBinary(t, repoRoot, "fixture-wire") + fixture := createImportBenchFixture(t, 10, currentWireModulePath, repoRoot) + _ = runWireBenchCommand(t, bin, fixture, t.TempDir(), filepath.Join(t.TempDir(), "gocache")) +} + +func TestImportBenchUsesStockArchive(t *testing.T) { + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + stockDir := extractStockWire(t, repoRoot, stockWireCommit) + if _, err := os.Stat(filepath.Join(stockDir, "cmd", "wire", "main.go")); err != nil { + t.Fatalf("stock archive missing cmd/wire: %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + cmd := exec.CommandContext(ctx, "go", "list", "./cmd/wire") + cmd.Dir = stockDir + cmd.Env = benchEnv(t.TempDir(), filepath.Join(t.TempDir(), "gocache")) + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("stock archive not buildable: %v\n%s", err, out) + } +} + +func importBenchRepoRoot() (string, error) { + wd, err := os.Getwd() + if err != nil { + return "", err + } + return filepath.Clean(filepath.Join(wd, "..", "..")), nil +} diff --git a/internal/wire/output_cache.go b/internal/wire/output_cache.go index 
7d384fb..35eacfe 100644 --- a/internal/wire/output_cache.go +++ b/internal/wire/output_cache.go @@ -104,7 +104,7 @@ func outputCacheEnabled(ctx context.Context, wd string, env []string) bool { if effectiveLoaderMode(ctx, wd, env) == loader.ModeFallback { return false } - return envValue(env, "WIRE_LOADER_ARTIFACTS") == "1" + return envValue(env, "WIRE_LOADER_ARTIFACTS") != "0" } func outputCachePath(env []string, key string) (string, error) { diff --git a/scripts/import-benchmarks.sh b/scripts/import-benchmarks.sh new file mode 100755 index 0000000..f1f4e58 --- /dev/null +++ b/scripts/import-benchmarks.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$ROOT_DIR" + +export GOCACHE="${GOCACHE:-/tmp/gocache}" +export GOMODCACHE="${GOMODCACHE:-/tmp/gomodcache}" + +usage() { + cat <<'EOF' +Usage: + scripts/import-benchmarks.sh table + +Commands: + table Print the 10/100/1000 import stock-vs-current benchmark table. 
+EOF +} + +case "${1:-}" in + table) + WIRE_IMPORT_BENCH_TABLE=1 go test ./internal/wire -run TestPrintImportScaleBenchmarkTable -count=1 -v + ;; + ""|-h|--help|help) + usage + ;; + *) + echo "Unknown command: ${1}" >&2 + usage >&2 + exit 1 + ;; +esac From d6b36b16bcb8465a12602def5c9691dd1a4eb33c Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 15:52:16 -0500 Subject: [PATCH 15/82] chore: re-implement cache --- cmd/wire/cache_cmd.go | 181 +++++++++++++++++++++++++++++ cmd/wire/cache_cmd_test.go | 96 +++++++++++++++ cmd/wire/main.go | 2 + internal/wire/import_bench_test.go | 64 +++++++++- scripts/import-benchmarks.sh | 7 +- 5 files changed, 347 insertions(+), 3 deletions(-) create mode 100644 cmd/wire/cache_cmd.go create mode 100644 cmd/wire/cache_cmd_test.go diff --git a/cmd/wire/cache_cmd.go b/cmd/wire/cache_cmd.go new file mode 100644 index 0000000..cdbbd40 --- /dev/null +++ b/cmd/wire/cache_cmd.go @@ -0,0 +1,181 @@ +// Copyright 2018 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package main + +import ( + "context" + "flag" + "fmt" + "log" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/google/subcommands" +) + +const ( + loaderArtifactDirEnv = "WIRE_LOADER_ARTIFACT_DIR" + outputCacheDirEnv = "WIRE_OUTPUT_CACHE_DIR" + semanticCacheDirEnv = "WIRE_SEMANTIC_CACHE_DIR" +) + +var osUserCacheDir = os.UserCacheDir + +type cacheCmd struct { + clear bool +} + +type cacheTarget struct { + name string + path string +} + +func (*cacheCmd) Name() string { return "cache" } + +func (*cacheCmd) Synopsis() string { + return "inspect or clear the wire cache" +} + +func (*cacheCmd) Usage() string { + return `cache +cache clear +cache -clear + + By default, prints the cache directory. With -clear or clear, removes all + Wire-managed cache files. +` +} + +func (cmd *cacheCmd) SetFlags(f *flag.FlagSet) { + f.BoolVar(&cmd.clear, "clear", false, "clear Wire caches") +} + +func (cmd *cacheCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{}) subcommands.ExitStatus { + _ = ctx + clearRequested := cmd.clear + switch extra := f.Args(); len(extra) { + case 0: + if !clearRequested { + root, err := wireCacheRoot(os.Environ()) + if err != nil { + log.Println(err) + return subcommands.ExitFailure + } + fmt.Fprintln(os.Stdout, root) + return subcommands.ExitSuccess + } + case 1: + if extra[0] == "clear" { + clearRequested = true + break + } + log.Printf("unknown cache action %q", extra[0]) + log.Println(strings.TrimSpace(cmd.Usage())) + return subcommands.ExitFailure + default: + log.Println(strings.TrimSpace(cmd.Usage())) + return subcommands.ExitFailure + } + if !clearRequested { + log.Println(strings.TrimSpace(cmd.Usage())) + return subcommands.ExitFailure + } + cleared, err := clearWireCaches(os.Environ()) + if err != nil { + log.Printf("failed to clear cache: %v\n", err) + return subcommands.ExitFailure + } + root, err := wireCacheRoot(os.Environ()) + if err != nil { + log.Println(err) + return subcommands.ExitFailure + } + if 
len(cleared) == 0 { + log.Printf("cleared cache at %s\n", root) + return subcommands.ExitSuccess + } + log.Printf("cleared cache at %s\n", root) + return subcommands.ExitSuccess +} + +func wireCacheRoot(env []string) (string, error) { + base, err := osUserCacheDir() + if err != nil { + return "", fmt.Errorf("resolve user cache dir: %w", err) + } + return filepath.Join(base, "wire"), nil +} + +func clearWireCaches(env []string) ([]string, error) { + base, err := wireCacheRoot(env) + if err != nil { + return nil, err + } + targets := wireCacheTargets(env, filepath.Dir(base)) + cleared := make([]string, 0, len(targets)) + for _, target := range targets { + info, err := os.Stat(target.path) + if os.IsNotExist(err) { + continue + } + if err != nil { + return cleared, fmt.Errorf("stat %s cache: %w", target.name, err) + } + if !info.IsDir() { + if err := os.Remove(target.path); err != nil { + return cleared, fmt.Errorf("remove %s cache: %w", target.name, err) + } + } else if err := os.RemoveAll(target.path); err != nil { + return cleared, fmt.Errorf("remove %s cache: %w", target.name, err) + } + cleared = append(cleared, target.name) + } + return cleared, nil +} + +func wireCacheTargets(env []string, userCacheDir string) []cacheTarget { + baseWire := filepath.Join(userCacheDir, "wire") + targets := []cacheTarget{ + {name: "loader-artifacts", path: envValueDefault(env, loaderArtifactDirEnv, filepath.Join(baseWire, "loader-artifacts"))}, + {name: "discovery-cache", path: filepath.Join(baseWire, "discovery-cache")}, + {name: "semantic-artifacts", path: envValueDefault(env, semanticCacheDirEnv, filepath.Join(baseWire, "semantic-artifacts"))}, + {name: "output-cache", path: envValueDefault(env, outputCacheDirEnv, filepath.Join(baseWire, "output-cache"))}, + } + seen := make(map[string]bool, len(targets)) + deduped := make([]cacheTarget, 0, len(targets)) + for _, target := range targets { + cleaned := filepath.Clean(target.path) + if seen[cleaned] { + continue + } + 
seen[cleaned] = true + target.path = cleaned + deduped = append(deduped, target) + } + sort.Slice(deduped, func(i, j int) bool { return deduped[i].name < deduped[j].name }) + return deduped +} + +func envValueDefault(env []string, key, fallback string) string { + for i := len(env) - 1; i >= 0; i-- { + parts := strings.SplitN(env[i], "=", 2) + if len(parts) == 2 && parts[0] == key && parts[1] != "" { + return parts[1] + } + } + return fallback +} diff --git a/cmd/wire/cache_cmd_test.go b/cmd/wire/cache_cmd_test.go new file mode 100644 index 0000000..c0c74c1 --- /dev/null +++ b/cmd/wire/cache_cmd_test.go @@ -0,0 +1,96 @@ +package main + +import ( + "os" + "path/filepath" + "testing" +) + +func TestWireCacheTargetsDefault(t *testing.T) { + base := filepath.Join(t.TempDir(), "cache") + got := wireCacheTargets(nil, base) + want := map[string]string{ + "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), + "loader-artifacts": filepath.Join(base, "wire", "loader-artifacts"), + "output-cache": filepath.Join(base, "wire", "output-cache"), + "semantic-artifacts": filepath.Join(base, "wire", "semantic-artifacts"), + } + if len(got) != len(want) { + t.Fatalf("targets len = %d, want %d", len(got), len(want)) + } + for _, target := range got { + if target.path != want[target.name] { + t.Fatalf("%s path = %q, want %q", target.name, target.path, want[target.name]) + } + } +} + +func TestWireCacheRoot(t *testing.T) { + base := filepath.Join(t.TempDir(), "cache") + old := osUserCacheDir + osUserCacheDir = func() (string, error) { return base, nil } + defer func() { osUserCacheDir = old }() + + got, err := wireCacheRoot(nil) + if err != nil { + t.Fatalf("wireCacheRoot() error = %v", err) + } + want := filepath.Join(base, "wire") + if got != want { + t.Fatalf("wireCacheRoot() = %q, want %q", got, want) + } +} + +func TestWireCacheTargetsRespectOverrides(t *testing.T) { + base := filepath.Join(t.TempDir(), "cache") + env := []string{ + loaderArtifactDirEnv + "=" + 
filepath.Join(base, "loader"), + outputCacheDirEnv + "=" + filepath.Join(base, "output"), + semanticCacheDirEnv + "=" + filepath.Join(base, "semantic"), + } + got := wireCacheTargets(env, base) + want := map[string]string{ + "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), + "loader-artifacts": filepath.Join(base, "loader"), + "output-cache": filepath.Join(base, "output"), + "semantic-artifacts": filepath.Join(base, "semantic"), + } + for _, target := range got { + if target.path != want[target.name] { + t.Fatalf("%s path = %q, want %q", target.name, target.path, want[target.name]) + } + } +} + +func TestClearWireCachesRemovesTargets(t *testing.T) { + base := filepath.Join(t.TempDir(), "cache") + env := []string{ + loaderArtifactDirEnv + "=" + filepath.Join(base, "loader"), + outputCacheDirEnv + "=" + filepath.Join(base, "output"), + semanticCacheDirEnv + "=" + filepath.Join(base, "semantic"), + } + for _, target := range wireCacheTargets(env, base) { + if err := os.MkdirAll(target.path, 0o755); err != nil { + t.Fatalf("MkdirAll(%q): %v", target.path, err) + } + if err := os.WriteFile(filepath.Join(target.path, "marker"), []byte("x"), 0o644); err != nil { + t.Fatalf("WriteFile(%q): %v", target.path, err) + } + } + old := osUserCacheDir + osUserCacheDir = func() (string, error) { return base, nil } + defer func() { osUserCacheDir = old }() + + cleared, err := clearWireCaches(env) + if err != nil { + t.Fatalf("clearWireCaches() error = %v", err) + } + if len(cleared) != 4 { + t.Fatalf("cleared len = %d, want 4", len(cleared)) + } + for _, target := range wireCacheTargets(env, base) { + if _, err := os.Stat(target.path); !os.IsNotExist(err) { + t.Fatalf("%s still exists after clear, stat err = %v", target.path, err) + } + } +} diff --git a/cmd/wire/main.go b/cmd/wire/main.go index f7fd92f..515673c 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -49,6 +49,7 @@ func main() { subcommands.Register(subcommands.CommandsCommand(), "") 
subcommands.Register(subcommands.FlagsCommand(), "") subcommands.Register(subcommands.HelpCommand(), "") + subcommands.Register(&cacheCmd{}, "") subcommands.Register(&checkCmd{}, "") subcommands.Register(&diffCmd{}, "") subcommands.Register(&genCmd{}, "") @@ -69,6 +70,7 @@ func main() { "commands": true, // builtin "help": true, // builtin "flags": true, // builtin + "cache": true, "check": true, "diff": true, "gen": true, diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index e134bc2..770e938 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -16,6 +16,7 @@ import ( const ( importBenchEnv = "WIRE_IMPORT_BENCH_TABLE" + importBenchBreakdown = "WIRE_IMPORT_BENCH_BREAKDOWN" stockWireCommit = "9c25c9016f6825302537c4efdd5e897976f9c826" stockWireModulePath = "github.com/google/wire" currentWireModulePath = "github.com/goforj/wire" @@ -57,6 +58,57 @@ func TestPrintImportScaleBenchmarkTable(t *testing.T) { printImportBenchTable(t, rows) } +func TestPrintImportScaleBenchmarkBreakdown(t *testing.T) { + if os.Getenv(importBenchBreakdown) != "1" { + t.Skipf("%s not set", importBenchBreakdown) + } + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + currentBin := buildWireBinary(t, repoRoot, "current-wire") + stockDir := extractStockWire(t, repoRoot, stockWireCommit) + stockBin := buildWireBinary(t, stockDir, "stock-wire") + + const imports = 1000 + stockFixture := createImportBenchFixture(t, imports, stockWireModulePath, stockDir) + currentFixture := createImportBenchFixture(t, imports, currentWireModulePath, repoRoot) + + stockCold := medianDuration(runColdTrials(t, stockBin, stockFixture, importBenchTrials)) + currentCold := medianDuration(runColdTrials(t, currentBin, currentFixture, importBenchTrials)) + currentWarm := medianDuration(runWarmTrials(t, currentBin, currentFixture, importBenchTrials)) + + fmt.Printf("repo size: %d\n", imports) + fmt.Printf("stock cold: %s\n", 
formatMs(stockCold)) + fmt.Printf("current cold: %s\n", formatMs(currentCold)) + fmt.Printf("current unchanged: %s\n", formatMs(currentWarm)) + fmt.Printf("cold speedup: %s\n", formatSpeedup(stockCold, currentCold)) + fmt.Printf("unchanged speedup: %s\n", formatSpeedup(stockCold, currentWarm)) + fmt.Printf("cold gap: %s\n", formatMs(currentCold-stockCold)) + + home := t.TempDir() + goCache := filepath.Join(t.TempDir(), "gocache") + _, output := runWireBenchCommandOutput(t, currentBin, currentFixture, home, goCache, "-timings") + fmt.Println("current cold timings:") + for _, line := range strings.Split(output, "\n") { + if !strings.Contains(line, "wire: timing:") { + continue + } + if strings.Contains(line, "loader.custom.root.discovery=") || + strings.Contains(line, "loader.discovery.") || + strings.Contains(line, "load.packages.load=") || + strings.Contains(line, "loader.custom.typed.artifact_write=") || + strings.Contains(line, "loader.custom.typed.root_load.wall=") || + strings.Contains(line, "loader.custom.typed.discovery.wall=") || + strings.Contains(line, "loader.custom.typed.artifact_writes=") || + strings.Contains(line, "generate.package.") || + strings.Contains(line, "wire.Generate=") || + strings.Contains(line, "total=") { + fmt.Println(line) + } + } +} + func buildWireBinary(t *testing.T, dir, name string) string { t.Helper() out := filepath.Join(t.TempDir(), name) @@ -147,7 +199,15 @@ func createImportBenchFixture(t *testing.T, imports int, wireModulePath, wireRep func runWireBenchCommand(t *testing.T, bin, pkgDir, home, goCache string) time.Duration { t.Helper() - cmd := exec.Command(bin, "gen") + d, _ := runWireBenchCommandOutput(t, bin, pkgDir, home, goCache) + return d +} + +func runWireBenchCommandOutput(t *testing.T, bin, pkgDir, home, goCache string, extraArgs ...string) (time.Duration, string) { + t.Helper() + args := []string{"gen"} + args = append(args, extraArgs...) + cmd := exec.Command(bin, args...) 
cmd.Dir = pkgDir cmd.Env = append(benchEnv(home, goCache), "WIRE_LOADER_ARTIFACTS=1") var stderr bytes.Buffer @@ -157,7 +217,7 @@ func runWireBenchCommand(t *testing.T, bin, pkgDir, home, goCache string) time.D if err := cmd.Run(); err != nil { t.Fatalf("run %s in %s: %v\n%s", bin, pkgDir, err, stderr.String()) } - return time.Since(start) + return time.Since(start), stderr.String() } func runColdTrials(t *testing.T, bin, pkgDir string, trials int) []time.Duration { diff --git a/scripts/import-benchmarks.sh b/scripts/import-benchmarks.sh index f1f4e58..e1c1f97 100755 --- a/scripts/import-benchmarks.sh +++ b/scripts/import-benchmarks.sh @@ -12,9 +12,11 @@ usage() { cat <<'EOF' Usage: scripts/import-benchmarks.sh table + scripts/import-benchmarks.sh breakdown Commands: - table Print the 10/100/1000 import stock-vs-current benchmark table. + table Print the 10/100/1000 import stock-vs-current benchmark table. + breakdown Print a focused 1000-import cold/unchanged breakdown. EOF } @@ -22,6 +24,9 @@ case "${1:-}" in table) WIRE_IMPORT_BENCH_TABLE=1 go test ./internal/wire -run TestPrintImportScaleBenchmarkTable -count=1 -v ;; + breakdown) + WIRE_IMPORT_BENCH_BREAKDOWN=1 go test ./internal/wire -run TestPrintImportScaleBenchmarkBreakdown -count=1 -v + ;; ""|-h|--help|help) usage ;; From 84087dd9dd619fb4d4f4d0316244a379e2e32822 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 17:59:33 -0500 Subject: [PATCH 16/82] fix: provider discovery --- internal/wire/parse.go | 13 ++++++++++++- internal/wire/parse_coverage_test.go | 21 +++++++-------------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index a825b4b..08a3c8a 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -574,6 +574,11 @@ func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, if !ok { return nil, false, nil } + for _, item := range setArt.Items { + if item.Kind == "bind" { + return nil, 
false, nil + } + } pset := &ProviderSet{ Pos: obj.Pos(), PkgPath: obj.Pkg().Path(), @@ -1856,5 +1861,11 @@ func bindShouldUsePointer(info *types.Info, call *ast.CallExpr) bool { fun := call.Fun.(*ast.SelectorExpr) // wire.Bind pkgName := fun.X.(*ast.Ident) // wire wireName := info.ObjectOf(pkgName).(*types.PkgName) // wire package - return wireName.Imported().Scope().Lookup("bindToUsePointer") != nil + if imported := wireName.Imported(); imported != nil { + if isWireImport(imported.Path()) { + return true + } + return imported.Scope().Lookup("bindToUsePointer") != nil + } + return false } diff --git a/internal/wire/parse_coverage_test.go b/internal/wire/parse_coverage_test.go index 3a23d18..c3c4d8e 100644 --- a/internal/wire/parse_coverage_test.go +++ b/internal/wire/parse_coverage_test.go @@ -367,7 +367,7 @@ func TestObjectCacheSemanticProviderSetFallback(t *testing.T) { } } -func TestObjectCacheSemanticProviderSetFallbackTypeOnlyForms(t *testing.T) { +func TestObjectCacheSemanticProviderSetSkipsBindArtifacts(t *testing.T) { t.Parallel() fset := token.NewFileSet() @@ -429,22 +429,15 @@ func TestObjectCacheSemanticProviderSetFallbackTypeOnlyForms(t *testing.T) { }, hasher: typeutil.MakeHasher(), } - item, errs := oc.get(setVar) + pset, ok, errs := oc.semanticProviderSet(setVar) if len(errs) > 0 { - t.Fatalf("oc.get(Set) errs = %v", errs) - } - pset, ok := item.(*ProviderSet) - if !ok || pset == nil { - t.Fatalf("oc.get(Set) type = %T, want *ProviderSet", item) - } - if len(pset.Bindings) != 1 { - t.Fatalf("bindings len = %d, want 1", len(pset.Bindings)) + t.Fatalf("semanticProviderSet(Set) errs = %v", errs) } - if len(pset.Providers) != 1 || !pset.Providers[0].IsStruct { - t.Fatalf("providers = %+v, want one struct provider", pset.Providers) + if ok { + t.Fatalf("semanticProviderSet(Set) ok = true, want false") } - if len(pset.Fields) != 1 || pset.Fields[0].Name != "Message" { - t.Fatalf("fields = %+v, want Message field", pset.Fields) + if pset != nil { + 
t.Fatalf("semanticProviderSet(Set) = %#v, want nil", pset) } } From e968dccc78c647fc54f3b47a3140cb545fbf91d5 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 21:01:43 -0500 Subject: [PATCH 17/82] chore: benchmark update --- internal/wire/import_bench_test.go | 1086 ++++++++++++++++++++++++++-- scripts/import-benchmarks.sh | 5 + 2 files changed, 1050 insertions(+), 41 deletions(-) diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index 770e938..aff71e2 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -4,6 +4,7 @@ import ( "archive/tar" "bytes" "context" + "encoding/json" "fmt" "io" "os" @@ -17,6 +18,8 @@ import ( const ( importBenchEnv = "WIRE_IMPORT_BENCH_TABLE" importBenchBreakdown = "WIRE_IMPORT_BENCH_BREAKDOWN" + importBenchScenarios = "WIRE_IMPORT_BENCH_SCENARIOS" + importBenchScenarioBD = "WIRE_IMPORT_BENCH_SCENARIO_BREAKDOWN" stockWireCommit = "9c25c9016f6825302537c4efdd5e897976f9c826" stockWireModulePath = "github.com/google/wire" currentWireModulePath = "github.com/goforj/wire" @@ -29,6 +32,27 @@ type importBenchRow struct { currentWarm time.Duration } +type importBenchScenarioRow struct { + profile string + localCount int + stdlibCount int + externalCount int + name string + stock time.Duration + current time.Duration +} + +type benchCaches struct { + home string + goCache string +} + +type benchGraphCounts struct { + local int + stdlib int + external int +} + const importBenchTrials = 3 func TestPrintImportScaleBenchmarkTable(t *testing.T) { @@ -42,6 +66,8 @@ func TestPrintImportScaleBenchmarkTable(t *testing.T) { currentBin := buildWireBinary(t, repoRoot, "current-wire") stockDir := extractStockWire(t, repoRoot, stockWireCommit) stockBin := buildWireBinary(t, stockDir, "stock-wire") + stockCaches := newBenchCaches(t) + currentCaches := newBenchCaches(t) sizes := []int{10, 100, 1000} rows := make([]importBenchRow, 0, len(sizes)) @@ -50,9 +76,9 @@ func 
TestPrintImportScaleBenchmarkTable(t *testing.T) { currentFixture := createImportBenchFixture(t, n, currentWireModulePath, repoRoot) rows = append(rows, importBenchRow{ imports: n, - stockCold: medianDuration(runColdTrials(t, stockBin, stockFixture, importBenchTrials)), - currentCold: medianDuration(runColdTrials(t, currentBin, currentFixture, importBenchTrials)), - currentWarm: medianDuration(runWarmTrials(t, currentBin, currentFixture, importBenchTrials)), + stockCold: medianDuration(runColdTrials(t, stockBin, stockFixture, stockCaches, importBenchTrials)), + currentCold: medianDuration(runColdTrials(t, currentBin, currentFixture, currentCaches, importBenchTrials)), + currentWarm: medianDuration(runWarmTrials(t, currentBin, currentFixture, currentCaches, importBenchTrials)), }) } printImportBenchTable(t, rows) @@ -69,14 +95,16 @@ func TestPrintImportScaleBenchmarkBreakdown(t *testing.T) { currentBin := buildWireBinary(t, repoRoot, "current-wire") stockDir := extractStockWire(t, repoRoot, stockWireCommit) stockBin := buildWireBinary(t, stockDir, "stock-wire") + stockCaches := newBenchCaches(t) + currentCaches := newBenchCaches(t) const imports = 1000 stockFixture := createImportBenchFixture(t, imports, stockWireModulePath, stockDir) currentFixture := createImportBenchFixture(t, imports, currentWireModulePath, repoRoot) - stockCold := medianDuration(runColdTrials(t, stockBin, stockFixture, importBenchTrials)) - currentCold := medianDuration(runColdTrials(t, currentBin, currentFixture, importBenchTrials)) - currentWarm := medianDuration(runWarmTrials(t, currentBin, currentFixture, importBenchTrials)) + stockCold := medianDuration(runColdTrials(t, stockBin, stockFixture, stockCaches, importBenchTrials)) + currentCold := medianDuration(runColdTrials(t, currentBin, currentFixture, currentCaches, importBenchTrials)) + currentWarm := medianDuration(runWarmTrials(t, currentBin, currentFixture, currentCaches, importBenchTrials)) fmt.Printf("repo size: %d\n", imports) 
fmt.Printf("stock cold: %s\n", formatMs(stockCold)) @@ -86,27 +114,202 @@ func TestPrintImportScaleBenchmarkBreakdown(t *testing.T) { fmt.Printf("unchanged speedup: %s\n", formatSpeedup(stockCold, currentWarm)) fmt.Printf("cold gap: %s\n", formatMs(currentCold-stockCold)) - home := t.TempDir() - goCache := filepath.Join(t.TempDir(), "gocache") - _, output := runWireBenchCommandOutput(t, currentBin, currentFixture, home, goCache, "-timings") + prewarmGoBenchCache(t, currentFixture, currentCaches) + _, output := runWireBenchCommandOutput(t, currentBin, currentFixture, currentCaches, "-timings") fmt.Println("current cold timings:") - for _, line := range strings.Split(output, "\n") { - if !strings.Contains(line, "wire: timing:") { - continue - } - if strings.Contains(line, "loader.custom.root.discovery=") || - strings.Contains(line, "loader.discovery.") || - strings.Contains(line, "load.packages.load=") || - strings.Contains(line, "loader.custom.typed.artifact_write=") || - strings.Contains(line, "loader.custom.typed.root_load.wall=") || - strings.Contains(line, "loader.custom.typed.discovery.wall=") || - strings.Contains(line, "loader.custom.typed.artifact_writes=") || - strings.Contains(line, "generate.package.") || - strings.Contains(line, "wire.Generate=") || - strings.Contains(line, "total=") { - fmt.Println(line) - } + printScenarioTimingLines(output) +} + +func TestPrintImportScenarioBenchmarkTable(t *testing.T) { + if os.Getenv(importBenchScenarios) != "1" { + t.Skipf("%s not set", importBenchScenarios) } + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + currentBin := buildWireBinary(t, repoRoot, "current-wire") + stockDir := extractStockWire(t, repoRoot, stockWireCommit) + stockBin := buildWireBinary(t, stockDir, "stock-wire") + + type appBenchProfile struct { + localPkgs int + depPkgs int + external bool + label string + } + profiles := []appBenchProfile{ + {localPkgs: 10, depPkgs: 25, label: "local"}, + {localPkgs: 10, depPkgs: 
1000, label: "local-high"}, + {localPkgs: 10, depPkgs: 25, external: true, label: "external"}, + {localPkgs: 10, depPkgs: 100, external: true, label: "external"}, + } + rows := make([]importBenchScenarioRow, 0, len(profiles)*6) + for _, profile := range profiles { + shapeFixture := createAppShapeBenchFixture(t, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot) + shapeCounts := goListGraphCounts(t, shapeFixture, "example.com/appbench", newBenchCaches(t)) + rows = append(rows, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "cold run", + stock: medianDuration(runAppColdTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, importBenchTrials)), + current: medianDuration(runAppColdTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, importBenchTrials)), + }, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "unchanged rerun", + stock: medianDuration(runAppWarmTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, importBenchTrials)), + current: medianDuration(runAppWarmTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, importBenchTrials)), + }, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "body-only local edit", + stock: medianDuration(runAppScenarioTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, "body", importBenchTrials)), + current: medianDuration(runAppScenarioTrials(t, currentBin, profile.localPkgs, 
profile.depPkgs, profile.external, currentWireModulePath, repoRoot, "body", importBenchTrials)), + }, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "shape change", + stock: medianDuration(runAppScenarioTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, "shape", importBenchTrials)), + current: medianDuration(runAppScenarioTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, "shape", importBenchTrials)), + }, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "import change", + stock: medianDuration(runAppScenarioTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, "import", importBenchTrials)), + current: medianDuration(runAppScenarioTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, "import", importBenchTrials)), + }, + importBenchScenarioRow{ + profile: profile.label, + localCount: shapeCounts.local, + stdlibCount: shapeCounts.stdlib, + externalCount: shapeCounts.external, + name: "known import toggle", + stock: medianDuration(runAppKnownToggleTrials(t, stockBin, profile.localPkgs, profile.depPkgs, profile.external, stockWireModulePath, stockDir, importBenchTrials)), + current: medianDuration(runAppKnownToggleTrials(t, currentBin, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot, importBenchTrials)), + }, + ) + } + printImportScenarioBenchTable(t, rows) +} + +func TestPrintImportScenarioBenchmarkBreakdown(t *testing.T) { + if os.Getenv(importBenchScenarioBD) != "1" { + t.Skipf("%s not set", importBenchScenarioBD) + } + repoRoot, err := importBenchRepoRoot() + if err != nil { + 
t.Fatal(err) + } + currentBin := buildWireBinary(t, repoRoot, "current-wire") + stockDir := extractStockWire(t, repoRoot, stockWireCommit) + stockBin := buildWireBinary(t, stockDir, "stock-wire") + + const ( + localPkgs = 10 + depPkgs = 1000 + ) + + stockPkgDir := createAppShapeBenchFixture(t, localPkgs, depPkgs, false, stockWireModulePath, stockDir) + currentPkgDir := createAppShapeBenchFixture(t, localPkgs, depPkgs, false, currentWireModulePath, repoRoot) + stockCaches := newBenchCaches(t) + currentCaches := newBenchCaches(t) + + prewarmGoBenchCache(t, stockPkgDir, stockCaches) + _ = runWireBenchCommand(t, stockBin, stockPkgDir, stockCaches) + writeAppShapeControllerFile(t, filepath.Dir(stockPkgDir), 0, "shape") + _ = runWireBenchCommand(t, stockBin, stockPkgDir, stockCaches) + writeAppShapeControllerFile(t, filepath.Dir(stockPkgDir), 0, "base") + stockDur := runWireBenchCommand(t, stockBin, stockPkgDir, stockCaches) + + prewarmGoBenchCache(t, currentPkgDir, currentCaches) + _ = runWireBenchCommand(t, currentBin, currentPkgDir, currentCaches) + writeAppShapeControllerFile(t, filepath.Dir(currentPkgDir), 0, "shape") + _ = runWireBenchCommand(t, currentBin, currentPkgDir, currentCaches) + writeAppShapeControllerFile(t, filepath.Dir(currentPkgDir), 0, "base") + currentDur, currentOutput := runWireBenchCommandOutput(t, currentBin, currentPkgDir, currentCaches, "-timings") + + fmt.Printf("scenario: local=%d dep=%d known import toggle\n", localPkgs, depPkgs) + fmt.Printf("stock: %s\n", formatMs(stockDur)) + fmt.Printf("current: %s\n", formatMs(currentDur)) + fmt.Printf("speedup: %s\n", formatSpeedup(stockDur, currentDur)) + fmt.Println("current timings:") + printScenarioTimingLines(currentOutput) +} + +func runAppColdTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + pkgDir := createAppShapeBenchFixture(t, features, 
depPkgs, external, wireModulePath, wireReplaceDir) + for i := 0; i < trials; i++ { + caches := newBenchCaches(t) + prewarmGoBenchCache(t, pkgDir, caches) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runAppWarmTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + pkgDir := createAppShapeBenchFixture(t, features, depPkgs, external, wireModulePath, wireReplaceDir) + caches := newBenchCaches(t) + for i := 0; i < trials; i++ { + resetAppShapeBenchFixture(t, pkgDir, features) + prewarmGoBenchCache(t, pkgDir, caches) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runAppScenarioTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir, variant string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + pkgDir := createAppShapeBenchFixture(t, features, depPkgs, external, wireModulePath, wireReplaceDir) + caches := newBenchCaches(t) + root := filepath.Dir(pkgDir) + for i := 0; i < trials; i++ { + resetAppShapeBenchFixture(t, pkgDir, features) + prewarmGoBenchCache(t, pkgDir, caches) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, variant) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runAppKnownToggleTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + pkgDir := createAppShapeBenchFixture(t, features, depPkgs, external, wireModulePath, wireReplaceDir) + caches := newBenchCaches(t) + root := filepath.Dir(pkgDir) + 
for i := 0; i < trials; i++ { + resetAppShapeBenchFixture(t, pkgDir, features) + prewarmGoBenchCache(t, pkgDir, caches) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, "shape") + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, "base") + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations } func buildWireBinary(t *testing.T, dir, name string) string { @@ -122,6 +325,14 @@ func buildWireBinary(t *testing.T, dir, name string) string { return out } +func newBenchCaches(t *testing.T) benchCaches { + t.Helper() + return benchCaches{ + home: t.TempDir(), + goCache: filepath.Join(t.TempDir(), "gocache"), + } +} + func extractStockWire(t *testing.T, repoRoot, commit string) string { t.Helper() tmp := t.TempDir() @@ -183,8 +394,7 @@ func createImportBenchFixture(t *testing.T, imports int, wireModulePath, wireRep if err := os.MkdirAll(dir, 0o755); err != nil { t.Fatal(err) } - src := fmt.Sprintf("package dep%04d\n\ntype T struct{}\n\nfunc Provide() *T { return &T{} }\n", i) - if err := os.WriteFile(filepath.Join(dir, "dep.go"), []byte(src), 0o644); err != nil { + if err := os.WriteFile(filepath.Join(dir, "dep.go"), []byte(importBenchDepFile(i, "base")), 0o644); err != nil { t.Fatal(err) } } @@ -197,19 +407,614 @@ func createImportBenchFixture(t *testing.T, imports int, wireModulePath, wireRep return filepath.Join(root, "app") } -func runWireBenchCommand(t *testing.T, bin, pkgDir, home, goCache string) time.Duration { +func createAppShapeBenchFixture(t *testing.T, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string) string { + t.Helper() + root := t.TempDir() + modulePath := "example.com/appbench" + if err := os.WriteFile(filepath.Join(root, "go.mod"), []byte(appShapeGoMod(modulePath, wireModulePath, wireReplaceDir, external)), 0o644); err != nil { + t.Fatal(err) + } + if external { + seedAppShapeExternalGoSum(t, root) + } 
+ for i := 0; i < depPkgs; i++ { + writeAppShapeFile(t, filepath.Join(root, "internal", fmt.Sprintf("dep%04d", i), "dep.go"), appShapeDepFile(i)) + } + writeAppShapeFile(t, filepath.Join(root, "internal", "logger", "logger.go"), appShapeLoggerFile(modulePath)) + writeAppShapeFile(t, filepath.Join(root, "internal", "cache", "cache.go"), appShapeCacheFile(modulePath)) + writeAppShapeFile(t, filepath.Join(root, "internal", "db", "db.go"), appShapeDBFile(modulePath)) + writeAppShapeFile(t, filepath.Join(root, "internal", "config", "config.go"), appShapeConfigFile(modulePath)) + writeAppShapeFile(t, filepath.Join(root, "internal", "metrics", "metrics.go"), appShapeMetricsFile(modulePath)) + writeAppShapeFile(t, filepath.Join(root, "internal", "httpx", "httpx.go"), appShapeHTTPXFile(modulePath)) + if external { + writeAppShapeFile(t, filepath.Join(root, "internal", "extsink", "extsink.go"), appShapeExtSinkFile(modulePath)) + } + writeAppShapeFile(t, filepath.Join(root, "wire", "app.go"), appShapeAppFile(modulePath, features)) + writeAppShapeFile(t, filepath.Join(root, "wire", "wire.go"), appShapeWireFile(modulePath, wireModulePath, features, external)) + for i := 0; i < features; i++ { + writeAppShapeFile(t, filepath.Join(root, "internal", fmt.Sprintf("feature%04d", i), "feature.go"), appShapeFeatureFile(modulePath, wireModulePath, i, depPkgs, external)) + writeAppShapeControllerFile(t, root, i, "base") + } + return filepath.Join(root, "wire") +} + +func writeAppShapeFile(t *testing.T, path, content string) { + t.Helper() + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatal(err) + } +} + +func writeAppShapeControllerFile(t *testing.T, root string, index int, variant string) { + t.Helper() + path := filepath.Join(root, "internal", fmt.Sprintf("feature%04d", index), "controller.go") + if err := os.WriteFile(path, 
[]byte(appShapeControllerFile("example.com/appbench", index, variant)), 0o644); err != nil { + t.Fatal(err) + } +} + +func seedAppShapeExternalGoSum(t *testing.T, root string) { + t.Helper() + const source = "/private/tmp/test/go.sum" + data, err := os.ReadFile(source) + if err != nil { + return + } + if err := os.WriteFile(filepath.Join(root, "go.sum"), data, 0o644); err != nil { + t.Fatalf("write seeded go.sum: %v", err) + } +} + +func resetAppShapeBenchFixture(t *testing.T, pkgDir string, features int) { + t.Helper() + root := filepath.Dir(pkgDir) + for i := 0; i < features; i++ { + writeAppShapeControllerFile(t, root, i, "base") + } +} + +func appShapeGoMod(modulePath, wireModulePath, wireReplaceDir string, external bool) string { + extraRequires := "" + if external { + extraRequires = ` + github.com/alecthomas/kong v1.14.0 + github.com/charmbracelet/lipgloss v1.1.0 + github.com/fsnotify/fsnotify v1.7.0 + github.com/glebarez/sqlite v1.11.0 + github.com/goforj/cache v0.1.5 + github.com/goforj/crypt v1.1.0 + github.com/goforj/env/v2 v2.3.0 + github.com/goforj/httpx v1.1.0 + github.com/goforj/null/v6 v6.0.2 + github.com/goforj/queue v0.1.5 + github.com/goforj/queue/driver/redisqueue v0.1.5 + github.com/goforj/scheduler v1.4.0 + github.com/goforj/storage v0.2.5 + github.com/goforj/storage/driver/localstorage v0.2.5 + github.com/goforj/storage/driver/redisstorage v0.2.5 + github.com/goforj/str v1.3.0 + github.com/google/go-cmp v0.6.0 + github.com/google/subcommands v1.2.0 + github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.3 + github.com/hibiken/asynq v0.26.0 + github.com/imroc/req/v3 v3.57.0 + github.com/labstack/echo/v4 v4.15.1 + github.com/pmezard/go-difflib v1.0.0 + github.com/redis/go-redis/v9 v9.17.2 + github.com/rs/zerolog v1.34.0 + github.com/shirou/gopsutil/v4 v4.26.2 + golang.org/x/mod v0.33.0 + golang.org/x/net v0.50.0 + golang.org/x/sync v0.19.0 + golang.org/x/sys v0.41.0 + golang.org/x/term v0.40.0 + golang.org/x/tools v0.42.0 + 
gopkg.in/yaml.v3 v3.0.1 + gorm.io/driver/mysql v1.6.0 + gorm.io/driver/postgres v1.6.0 + gorm.io/gorm v1.31.1` + } + return fmt.Sprintf(`module %s + +go 1.26 + +require ( + %s v0.0.0%s +) + +replace %s => %s +`, modulePath, wireModulePath, extraRequires, wireModulePath, wireReplaceDir) +} + +func appShapeLoggerFile(modulePath string) string { + return `package logger + +import ( + "context" + "encoding/json" + "io" + "os" + "sync" + "time" +) + +type Logger struct { + sink io.Writer + mu sync.Mutex +} + +func NewLogger() *Logger { return &Logger{sink: os.Stdout} } + +func (l *Logger) Log(ctx context.Context, msg string, attrs map[string]string) { + l.mu.Lock() + defer l.mu.Unlock() + _, _ = json.Marshal(map[string]any{ + "ctx": ctx != nil, + "msg": msg, + "attrs": attrs, + "time": time.Now().UTC().Format(time.RFC3339Nano), + }) +} +` +} + +func appShapeCacheFile(modulePath string) string { + return `package cache + +type Manager struct{} + +func NewManager() *Manager { return &Manager{} } +` +} + +func appShapeDBFile(modulePath string) string { + return `package db + +import ( + "context" + "database/sql" + "net/url" + "path/filepath" +) + +type DB struct { + driver string + dsn string +} + +func NewDB() *DB { + _ = filepath.Join("var", "lib", "appbench") + _ = sql.LevelDefault + u := &url.URL{Scheme: "postgres", Host: "localhost", Path: "/appbench"} + return &DB{driver: "postgres", dsn: u.String()} +} + +func (db *DB) PingContext(context.Context) error { return nil } +` +} + +func appShapeDepFile(index int) string { + return fmt.Sprintf(`package dep%04d + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "path/filepath" + "strings" +) + +type Value struct { + Name string +} + +func Provide() Value { + sum := sha256.Sum256([]byte(fmt.Sprintf("dep-%%04d", %d))) + return Value{ + Name: filepath.Join("deps", strings.ToLower(hex.EncodeToString(sum[:])))[:16], + } +} +`, index, index) +} + +func appShapeConfigFile(modulePath string) string { + return `package 
config + +import ( + "encoding/json" + "os" + "strconv" +) + +type Config struct { + Port int + Service string +} + +func NewConfig() *Config { + cfg := &Config{Port: 8080, Service: "appbench"} + if v := os.Getenv("APPBENCH_PORT"); v != "" { + if port, err := strconv.Atoi(v); err == nil { + cfg.Port = port + } + } + _, _ = json.Marshal(cfg) + return cfg +} +` +} + +func appShapeMetricsFile(modulePath string) string { + return `package metrics + +import ( + "expvar" + "fmt" + "sync/atomic" +) + +type Metrics struct { + requests atomic.Int64 + name string +} + +func NewMetrics() *Metrics { + expvar.NewString("appbench_name").Set("appbench") + return &Metrics{name: fmt.Sprintf("appbench_%s", "requests")} +} +` +} + +func appShapeHTTPXFile(modulePath string) string { + return `package httpx + +import ( + "context" + "net/http" + "net/http/httptest" +) + +type Client struct { + client *http.Client +} + +func NewClient() *Client { + req := httptest.NewRequest(http.MethodGet, "http://example.com", nil) + _ = req.WithContext(context.Background()) + return &Client{client: &http.Client{}} +} +` +} + +func appShapeExtSinkFile(modulePath string) string { + return `package extsink + +import ( + "context" + "fmt" + "os" + + _ "github.com/alecthomas/kong" + _ "github.com/charmbracelet/lipgloss" + _ "github.com/charmbracelet/lipgloss/table" + "github.com/fsnotify/fsnotify" + _ "github.com/glebarez/sqlite" + _ "github.com/goforj/cache" + _ "github.com/goforj/crypt" + _ "github.com/goforj/env/v2" + _ "github.com/goforj/httpx" + _ "github.com/goforj/null/v6" + _ "github.com/goforj/queue" + _ "github.com/goforj/queue/driver/redisqueue" + _ "github.com/goforj/scheduler" + _ "github.com/goforj/storage" + _ "github.com/goforj/storage/driver/localstorage" + _ "github.com/goforj/storage/driver/redisstorage" + _ "github.com/goforj/str" + "github.com/google/go-cmp/cmp" + "github.com/google/subcommands" + _ "github.com/google/uuid" + _ "github.com/gorilla/websocket" + _ 
"github.com/hibiken/asynq" + _ "github.com/imroc/req/v3" + _ "github.com/labstack/echo/v4" + _ "github.com/labstack/echo/v4/middleware" + "github.com/pmezard/go-difflib/difflib" + _ "github.com/redis/go-redis/v9" + _ "github.com/rs/zerolog" + _ "github.com/shirou/gopsutil/v4/cpu" + _ "github.com/shirou/gopsutil/v4/disk" + _ "github.com/shirou/gopsutil/v4/host" + _ "github.com/shirou/gopsutil/v4/mem" + _ "github.com/shirou/gopsutil/v4/net" + _ "github.com/shirou/gopsutil/v4/process" + "golang.org/x/mod/modfile" + _ "golang.org/x/net/http2" + _ "golang.org/x/net/http2/h2c" + "golang.org/x/sync/errgroup" + "golang.org/x/sys/unix" + _ "golang.org/x/term" + "golang.org/x/tools/go/packages" + _ "gopkg.in/yaml.v3" + _ "gorm.io/driver/mysql" + _ "gorm.io/driver/postgres" + _ "gorm.io/gorm" +) + +type Sink struct { + label string +} + +func NewSink() *Sink { + _ = cmp.Equal("a", "b") + _ = difflib.UnifiedDiff{} + _, _ = modfile.Parse("go.mod", []byte("module example.com/appbench"), nil) + _, _ = packages.Load(&packages.Config{Mode: packages.NeedName}, "fmt") + var g errgroup.Group + g.Go(func() error { return nil }) + _ = unix.Getpid() + _ = fsnotify.Event{Name: os.TempDir()} + _ = subcommands.ExitSuccess + return &Sink{label: fmt.Sprintf("sink:%v", context.Background() != nil)} +} +` +} + +func appShapeFeatureFile(modulePath, wireModulePath string, index, depPkgs int, external bool) string { + pkg := fmt.Sprintf("feature%04d", index) + var depImports strings.Builder + var depUse strings.Builder + for i := 0; i < depPkgs; i++ { + depImports.WriteString(fmt.Sprintf("\tdep%04d %q\n", i, fmt.Sprintf("%s/internal/dep%04d", modulePath, i))) + depUse.WriteString(fmt.Sprintf("\t_ = dep%04d.Provide()\n", i)) + } + externalImport := "" + externalArg := "" + externalField := "" + externalUse := "" + if external { + externalImport = fmt.Sprintf("\t%q\n", modulePath+"/internal/extsink") + externalArg = ", sink *extsink.Sink" + externalField = "\tsink *extsink.Sink\n" + externalUse = 
"\t_ = sink\n" + } + return fmt.Sprintf(`package %s + +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "strconv" + "time" + wire %q + %q + %q + %q + %q + %q +%s +%s +) + +type Repo struct { + db *db.DB + config *config.Config + metrics *metrics.Metrics +%s} + +type Service struct { + repo *Repo + logger *logger.Logger + client *httpx.Client +} + +func NewRepo(dbConn *db.DB, cfg *config.Config, m *metrics.Metrics, l *logger.Logger%s) *Repo { + _, _ = json.Marshal(map[string]any{"feature": %d, "service": cfg.Service}) + l.Log(context.Background(), "repo.init", map[string]string{"feature": strconv.Itoa(%d)}) +%s return &Repo{db: dbConn, config: cfg, metrics: m} +} + +func NewService(repo *Repo, l *logger.Logger, client *httpx.Client) *Service { + _, _ = url.Parse(fmt.Sprintf("https://example.com/%%04d", %d)) + _ = time.Second + return &Service{repo: repo, logger: l, client: client} +} + +var Set = wire.NewSet(NewRepo, NewService, NewController) +`, pkg, wireModulePath, modulePath+"/internal/config", modulePath+"/internal/db", modulePath+"/internal/httpx", modulePath+"/internal/logger", modulePath+"/internal/metrics", depImports.String(), externalImport, externalField, externalArg, index, index, depUse.String()+externalUse, index) +} + +func appShapeControllerFile(modulePath string, index int, variant string) string { + pkg := fmt.Sprintf("feature%04d", index) + imports := []string{ + `"context"`, + `"fmt"`, + `"net/http"`, + `"strconv"`, + `"` + modulePath + `/internal/logger"`, + } + if variant == "shape" { + imports = append(imports, `"`+modulePath+`/internal/db"`) + } + if variant == "import" { + imports = append(imports, `"strings"`) + } + bodyLine := "" + switch variant { + case "body": + bodyLine = "\t_ = \"body-edit\"\n" + case "import": + bodyLine = "\t_ = strings.TrimSpace(\" import-edit \")\n" + } + extraField := "" + extraArg := "" + extraInit := "" + if variant == "shape" { + extraField = "\tdb *db.DB\n" + extraArg = ", d *db.DB" + extraInit = 
"\t\tdb: d,\n" + } + return fmt.Sprintf(`package %s + +import ( + %s +) + +type Controller struct { + logger *logger.Logger + service *Service +%s} + +func NewController(l *logger.Logger, s *Service%s) *Controller { +%s l.Log(context.Background(), "controller.init", map[string]string{"feature": strconv.Itoa(%d)}) + _ = http.MethodGet + _ = fmt.Sprintf("feature-%%d", %d) + return &Controller{ + logger: l, + service: s, +%s } +} +`, pkg, strings.Join(imports, "\n\t"), extraField, extraArg, bodyLine, index, index, extraInit) +} + +func appShapeAppFile(modulePath string, features int) string { + var b strings.Builder + b.WriteString("package wire\n\n") + if features > 0 { + b.WriteString("import (\n") + for i := 0; i < features; i++ { + b.WriteString(fmt.Sprintf("\tfeature%04d %q\n", i, fmt.Sprintf("%s/internal/feature%04d", modulePath, i))) + } + b.WriteString(")\n\n") + } + b.WriteString("type App struct{}\n\n") + b.WriteString("func NewApp(") + for i := 0; i < features; i++ { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(fmt.Sprintf("_ *feature%04d.Controller", i)) + } + b.WriteString(") *App {\n\treturn &App{}\n}\n") + return b.String() +} + +func appShapeWireFile(modulePath, wireModulePath string, features int, external bool) string { + var b strings.Builder + b.WriteString("//go:build wireinject\n\n") + b.WriteString("package wire\n\n") + b.WriteString("import (\n") + b.WriteString(fmt.Sprintf("\twire %q\n", wireModulePath)) + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/config")) + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/db")) + if external { + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/extsink")) + } + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/httpx")) + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/logger")) + b.WriteString(fmt.Sprintf("\t%q\n", modulePath+"/internal/metrics")) + for i := 0; i < features; i++ { + b.WriteString(fmt.Sprintf("\t%q\n", 
fmt.Sprintf("%s/internal/feature%04d", modulePath, i))) + } + b.WriteString(")\n\n") + b.WriteString("func Initialize() *App {\n\twire.Build(\n") + b.WriteString("\t\tconfig.NewConfig,\n") + b.WriteString("\t\tlogger.NewLogger,\n") + b.WriteString("\t\tdb.NewDB,\n") + if external { + b.WriteString("\t\textsink.NewSink,\n") + } + b.WriteString("\t\thttpx.NewClient,\n") + b.WriteString("\t\tmetrics.NewMetrics,\n") + for i := 0; i < features; i++ { + b.WriteString(fmt.Sprintf("\t\tfeature%04d.Set,\n", i)) + } + b.WriteString("\t\tNewApp,\n\t)\n\treturn nil\n}\n") + return b.String() +} + +func runBodyEditTrials(t *testing.T, bin string, imports int, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + caches := newBenchCaches(t) + for i := 0; i < trials; i++ { + pkgDir := createImportBenchFixture(t, imports, wireModulePath, wireReplaceDir) + prewarmGoBenchCache(t, pkgDir, caches) + root := filepath.Dir(pkgDir) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeImportBenchDepFile(t, root, 0, "body") + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runShapeEditTrials(t *testing.T, bin string, imports int, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + caches := newBenchCaches(t) + for i := 0; i < trials; i++ { + pkgDir := createImportBenchFixture(t, imports, wireModulePath, wireReplaceDir) + prewarmGoBenchCache(t, pkgDir, caches) + root := filepath.Dir(pkgDir) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeImportBenchDepFile(t, root, 0, "shape") + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runImportChangeTrials(t *testing.T, bin string, imports int, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 
0, trials) + caches := newBenchCaches(t) + for i := 0; i < trials; i++ { + pkgDir := createImportBenchFixture(t, imports+1, wireModulePath, wireReplaceDir) + prewarmGoBenchCache(t, pkgDir, caches) + root := filepath.Dir(pkgDir) + writeImportBenchWireFile(t, root, imports, wireModulePath) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeImportBenchWireFile(t, root, imports+1, wireModulePath) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runKnownImportToggleTrials(t *testing.T, bin string, imports int, wireModulePath, wireReplaceDir string, trials int) []time.Duration { + t.Helper() + durations := make([]time.Duration, 0, trials) + caches := newBenchCaches(t) + for i := 0; i < trials; i++ { + pkgDir := createImportBenchFixture(t, imports+1, wireModulePath, wireReplaceDir) + prewarmGoBenchCache(t, pkgDir, caches) + root := filepath.Dir(pkgDir) + writeImportBenchWireFile(t, root, imports, wireModulePath) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeImportBenchWireFile(t, root, imports+1, wireModulePath) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + writeImportBenchWireFile(t, root, imports, wireModulePath) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) + } + return durations +} + +func runWireBenchCommand(t *testing.T, bin, pkgDir string, caches benchCaches) time.Duration { t.Helper() - d, _ := runWireBenchCommandOutput(t, bin, pkgDir, home, goCache) + d, _ := runWireBenchCommandOutput(t, bin, pkgDir, caches) return d } -func runWireBenchCommandOutput(t *testing.T, bin, pkgDir, home, goCache string, extraArgs ...string) (time.Duration, string) { +func runWireBenchCommandOutput(t *testing.T, bin, pkgDir string, caches benchCaches, extraArgs ...string) (time.Duration, string) { t.Helper() args := []string{"gen"} args = append(args, extraArgs...) cmd := exec.Command(bin, args...) 
cmd.Dir = pkgDir - cmd.Env = append(benchEnv(home, goCache), "WIRE_LOADER_ARTIFACTS=1") + cmd.Env = append(benchEnv(caches.home, caches.goCache), "WIRE_LOADER_ARTIFACTS=1") var stderr bytes.Buffer cmd.Stdout = io.Discard cmd.Stderr = &stderr @@ -220,25 +1025,96 @@ func runWireBenchCommandOutput(t *testing.T, bin, pkgDir, home, goCache string, return time.Since(start), stderr.String() } -func runColdTrials(t *testing.T, bin, pkgDir string, trials int) []time.Duration { +func prewarmGoBenchCache(t *testing.T, pkgDir string, caches benchCaches) { + t.Helper() + prepareBenchModule(t, pkgDir, caches) + cmd := exec.Command("go", "list", "-deps", "./...") + cmd.Dir = pkgDir + cmd.Env = benchEnv(caches.home, caches.goCache) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("prewarm go cache in %s: %v\n%s", pkgDir, err, output) + } +} + +func goListGraphCounts(t *testing.T, pkgDir, modulePath string, caches benchCaches) benchGraphCounts { + t.Helper() + prepareBenchModule(t, pkgDir, caches) + cmd := exec.Command("go", "list", "-deps", "-json", "./...") + cmd.Dir = pkgDir + cmd.Env = benchEnv(caches.home, caches.goCache) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("go list graph counts in %s: %v\n%s", pkgDir, err, output) + } + dec := json.NewDecoder(bytes.NewReader(output)) + seen := make(map[string]struct{}) + var counts benchGraphCounts + for { + var pkg struct { + ImportPath string + Standard bool + } + if err := dec.Decode(&pkg); err != nil { + if err == io.EOF { + break + } + t.Fatalf("decode graph counts for %s: %v", pkgDir, err) + } + if pkg.ImportPath == "" { + continue + } + if _, ok := seen[pkg.ImportPath]; ok { + continue + } + seen[pkg.ImportPath] = struct{}{} + switch { + case pkg.Standard: + counts.stdlib++ + case pkg.ImportPath == modulePath || strings.HasPrefix(pkg.ImportPath, modulePath+"/"): + counts.local++ + default: + counts.external++ + } + } + return counts +} + +func prepareBenchModule(t *testing.T, pkgDir 
string, caches benchCaches) { + t.Helper() + marker := filepath.Join(filepath.Dir(pkgDir), ".bench-module-ready") + if _, err := os.Stat(marker); err == nil { + return + } + cmd := exec.Command("go", "mod", "tidy") + cmd.Dir = filepath.Dir(pkgDir) + cmd.Env = benchEnv(caches.home, caches.goCache) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("prepare bench module in %s: %v\n%s", pkgDir, err, output) + } + if err := os.WriteFile(marker, []byte("ok\n"), 0o644); err != nil { + t.Fatalf("write module marker %s: %v", marker, err) + } +} + +func runColdTrials(t *testing.T, bin, pkgDir string, caches benchCaches, trials int) []time.Duration { t.Helper() durations := make([]time.Duration, 0, trials) for i := 0; i < trials; i++ { - home := t.TempDir() - goCache := filepath.Join(t.TempDir(), "gocache") - durations = append(durations, runWireBenchCommand(t, bin, pkgDir, home, goCache)) + prewarmGoBenchCache(t, pkgDir, caches) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) } return durations } -func runWarmTrials(t *testing.T, bin, pkgDir string, trials int) []time.Duration { +func runWarmTrials(t *testing.T, bin, pkgDir string, caches benchCaches, trials int) []time.Duration { t.Helper() durations := make([]time.Duration, 0, trials) for i := 0; i < trials; i++ { - home := t.TempDir() - goCache := filepath.Join(t.TempDir(), "gocache") - _ = runWireBenchCommand(t, bin, pkgDir, home, goCache) - durations = append(durations, runWireBenchCommand(t, bin, pkgDir, home, goCache)) + prewarmGoBenchCache(t, pkgDir, caches) + _ = runWireBenchCommand(t, bin, pkgDir, caches) + durations = append(durations, runWireBenchCommand(t, bin, pkgDir, caches)) } return durations } @@ -262,6 +1138,7 @@ func benchEnv(home, goCache string) []string { "HOME="+home, "GOCACHE="+goCache, "GOMODCACHE=/tmp/gomodcache", + "GOSUMDB=off", ) return env } @@ -304,6 +1181,33 @@ func importBenchWireFile(imports int, wireModulePath string) string { return b.String() 
} +func importBenchDepFile(i int, variant string) string { + switch variant { + case "body": + return fmt.Sprintf("package dep%04d\n\ntype T struct{}\n\nfunc Provide() *T {\n\t_ = \"body-edit\"\n\treturn &T{}\n}\n", i) + case "shape": + return fmt.Sprintf("package dep%04d\n\ntype T struct{ Extra int }\n\nfunc Provide() *T { return &T{} }\n", i) + default: + return fmt.Sprintf("package dep%04d\n\ntype T struct{}\n\nfunc Provide() *T { return &T{} }\n", i) + } +} + +func writeImportBenchWireFile(t *testing.T, root string, imports int, wireModulePath string) { + t.Helper() + path := filepath.Join(root, "app", "wire.go") + if err := os.WriteFile(path, []byte(importBenchWireFile(imports, wireModulePath)), 0o644); err != nil { + t.Fatal(err) + } +} + +func writeImportBenchDepFile(t *testing.T, root string, index int, variant string) { + t.Helper() + path := filepath.Join(root, fmt.Sprintf("dep%04d", index), "dep.go") + if err := os.WriteFile(path, []byte(importBenchDepFile(index, variant)), 0o644); err != nil { + t.Fatal(err) + } +} + func printImportBenchTable(t *testing.T, rows []importBenchRow) { t.Helper() fmt.Println("+-----------+-----------+--------------+-------------------+--------------+-------------------+") @@ -322,6 +1226,104 @@ func printImportBenchTable(t *testing.T, rows []importBenchRow) { fmt.Println("+-----------+-----------+--------------+-------------------+--------------+-------------------+") } +func printImportScenarioBenchTable(t *testing.T, rows []importBenchScenarioRow) { + t.Helper() + profileWidth := len("profile") + localWidth := len("local") + stdlibWidth := len("stdlib") + externalWidth := len("external") + changeTypeWidth := len("change type") + stockWidth := len("stock") + currentWidth := len("current") + speedupWidth := len("speedup") + for _, row := range rows { + profileWidth = maxInt(profileWidth, len(row.profile)) + localWidth = maxInt(localWidth, len(fmt.Sprintf("%d", row.localCount))) + stdlibWidth = maxInt(stdlibWidth, 
len(fmt.Sprintf("%d", row.stdlibCount))) + externalWidth = maxInt(externalWidth, len(fmt.Sprintf("%d", row.externalCount))) + changeTypeWidth = maxInt(changeTypeWidth, len(row.name)) + stockWidth = maxInt(stockWidth, len(formatMs(row.stock))) + currentWidth = maxInt(currentWidth, len(formatMs(row.current))) + speedupWidth = maxInt(speedupWidth, len(formatSpeedup(row.stock, row.current))) + } + sep := fmt.Sprintf("+-%s-+-%s-+-%s-+-%s-+-%s-+-%s-+-%s-+-%s-+", + strings.Repeat("-", profileWidth), + strings.Repeat("-", localWidth), + strings.Repeat("-", stdlibWidth), + strings.Repeat("-", externalWidth), + strings.Repeat("-", changeTypeWidth), + strings.Repeat("-", stockWidth), + strings.Repeat("-", currentWidth), + strings.Repeat("-", speedupWidth), + ) + fmt.Println(sep) + fmt.Printf("| %*s | %-*s | %-*s | %-*s | %-*s | %-*s | %-*s | %-*s |\n", + profileWidth, "profile", + localWidth, "local", + stdlibWidth, "stdlib", + externalWidth, "external", + changeTypeWidth, "change type", + stockWidth, "stock", + currentWidth, "current", + speedupWidth, "speedup", + ) + fmt.Println(sep) + for _, row := range rows { + fmt.Printf("| %*s | %-*d | %-*d | %-*d | %-*s | %-*s | %-*s | %-*s |\n", + profileWidth, row.profile, + localWidth, row.localCount, + stdlibWidth, row.stdlibCount, + externalWidth, row.externalCount, + changeTypeWidth, row.name, + stockWidth, formatMs(row.stock), + currentWidth, formatMs(row.current), + speedupWidth, formatSpeedup(row.stock, row.current), + ) + } + fmt.Println(sep) + fmt.Println() + fmt.Println("change types:") + fmt.Println(" cold run: first wire gen on a fresh Wire cache for that repo shape") + fmt.Println(" unchanged rerun: run wire gen again without changing any files") + fmt.Println(" body-only local edit: change local function body/content without changing imports, types, or constructor signatures") + fmt.Println(" shape change: change local provider/type shape such as constructor params, fields, or return shape") + fmt.Println(" import 
change: add or remove a local import, which can change discovered package shape") + fmt.Println(" known import toggle: switch back to a previously seen import/shape state in the same repo") +} + +func maxInt(a, b int) int { + if a > b { + return a + } + return b +} + +func printScenarioTimingLines(output string) { + for _, line := range strings.Split(output, "\n") { + if !strings.Contains(line, "wire: timing:") { + continue + } + if strings.Contains(line, "loader.custom.root.discovery=") || + strings.Contains(line, "loader.discovery.") || + strings.Contains(line, "load.packages.load=") || + strings.Contains(line, "load.debug") || + strings.Contains(line, "loader.custom.typed.artifact_read=") || + strings.Contains(line, "loader.custom.typed.artifact_decode=") || + strings.Contains(line, "loader.custom.typed.artifact_import_link=") || + strings.Contains(line, "loader.custom.typed.artifact_write=") || + strings.Contains(line, "loader.custom.typed.root_load.wall=") || + strings.Contains(line, "loader.custom.typed.discovery.wall=") || + strings.Contains(line, "loader.custom.typed.artifact_hits=") || + strings.Contains(line, "loader.custom.typed.artifact_misses=") || + strings.Contains(line, "loader.custom.typed.artifact_writes=") || + strings.Contains(line, "generate.package.") || + strings.Contains(line, "wire.Generate=") || + strings.Contains(line, "total=") { + fmt.Println(line) + } + } +} + func formatMs(d time.Duration) string { return fmt.Sprintf("%.1fms", float64(d)/float64(time.Millisecond)) } @@ -340,7 +1342,9 @@ func TestImportBenchFixtureGenerates(t *testing.T) { } bin := buildWireBinary(t, repoRoot, "fixture-wire") fixture := createImportBenchFixture(t, 10, currentWireModulePath, repoRoot) - _ = runWireBenchCommand(t, bin, fixture, t.TempDir(), filepath.Join(t.TempDir(), "gocache")) + caches := newBenchCaches(t) + prewarmGoBenchCache(t, fixture, caches) + _ = runWireBenchCommand(t, bin, fixture, caches) } func TestImportBenchUsesStockArchive(t *testing.T) { 
diff --git a/scripts/import-benchmarks.sh b/scripts/import-benchmarks.sh index e1c1f97..2eb98c2 100755 --- a/scripts/import-benchmarks.sh +++ b/scripts/import-benchmarks.sh @@ -12,10 +12,12 @@ usage() { cat <<'EOF' Usage: scripts/import-benchmarks.sh table + scripts/import-benchmarks.sh scenarios scripts/import-benchmarks.sh breakdown Commands: table Print the 10/100/1000 import stock-vs-current benchmark table. + scenarios Print the stock-vs-current change-type scenario table. breakdown Print a focused 1000-import cold/unchanged breakdown. EOF } @@ -24,6 +26,9 @@ case "${1:-}" in table) WIRE_IMPORT_BENCH_TABLE=1 go test ./internal/wire -run TestPrintImportScaleBenchmarkTable -count=1 -v ;; + scenarios) + WIRE_IMPORT_BENCH_SCENARIOS=1 go test ./internal/wire -run TestPrintImportScenarioBenchmarkTable -count=1 -v + ;; breakdown) WIRE_IMPORT_BENCH_BREAKDOWN=1 go test ./internal/wire -run TestPrintImportScaleBenchmarkBreakdown -count=1 -v ;; From ba95466e9ebc18668d8d729ffdb6c37b2fba9cf3 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 21:17:52 -0500 Subject: [PATCH 18/82] fix: ci --- cmd/wire/cache_cmd_test.go | 12 ++-- cmd/wire/check_cmd.go | 4 +- cmd/wire/diff_cmd.go | 6 +- cmd/wire/main.go | 10 ++-- cmd/wire/show_cmd.go | 4 +- internal/loader/custom.go | 90 +++++++++++++++--------------- internal/loader/discovery_cache.go | 18 +++--- internal/wire/import_bench_test.go | 17 ++++-- 8 files changed, 83 insertions(+), 78 deletions(-) diff --git a/cmd/wire/cache_cmd_test.go b/cmd/wire/cache_cmd_test.go index c0c74c1..83924e2 100644 --- a/cmd/wire/cache_cmd_test.go +++ b/cmd/wire/cache_cmd_test.go @@ -10,9 +10,9 @@ func TestWireCacheTargetsDefault(t *testing.T) { base := filepath.Join(t.TempDir(), "cache") got := wireCacheTargets(nil, base) want := map[string]string{ - "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), - "loader-artifacts": filepath.Join(base, "wire", "loader-artifacts"), - "output-cache": filepath.Join(base, "wire", 
"output-cache"), + "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), + "loader-artifacts": filepath.Join(base, "wire", "loader-artifacts"), + "output-cache": filepath.Join(base, "wire", "output-cache"), "semantic-artifacts": filepath.Join(base, "wire", "semantic-artifacts"), } if len(got) != len(want) { @@ -50,9 +50,9 @@ func TestWireCacheTargetsRespectOverrides(t *testing.T) { } got := wireCacheTargets(env, base) want := map[string]string{ - "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), - "loader-artifacts": filepath.Join(base, "loader"), - "output-cache": filepath.Join(base, "output"), + "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), + "loader-artifacts": filepath.Join(base, "loader"), + "output-cache": filepath.Join(base, "output"), "semantic-artifacts": filepath.Join(base, "semantic"), } for _, target := range got { diff --git a/cmd/wire/check_cmd.go b/cmd/wire/check_cmd.go index 897bec2..7857437 100644 --- a/cmd/wire/check_cmd.go +++ b/cmd/wire/check_cmd.go @@ -26,8 +26,8 @@ import ( ) type checkCmd struct { - tags string - profile profileFlags + tags string + profile profileFlags } // Name returns the subcommand name. diff --git a/cmd/wire/diff_cmd.go b/cmd/wire/diff_cmd.go index 5aad2f1..592cced 100644 --- a/cmd/wire/diff_cmd.go +++ b/cmd/wire/diff_cmd.go @@ -29,9 +29,9 @@ import ( ) type diffCmd struct { - headerFile string - tags string - profile profileFlags + headerFile string + tags string + profile profileFlags } // Name returns the subcommand name. 
diff --git a/cmd/wire/main.go b/cmd/wire/main.go index 515673c..c13b850 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -36,11 +36,11 @@ import ( ) const ( - ansiRed = "\033[1;31m" - ansiGreen = "\033[1;32m" - ansiReset = "\033[0m" - successSig = "✓ " - errorSig = "x " + ansiRed = "\033[1;31m" + ansiGreen = "\033[1;32m" + ansiReset = "\033[0m" + successSig = "✓ " + errorSig = "x " maxLoggedErrorLines = 5 ) diff --git a/cmd/wire/show_cmd.go b/cmd/wire/show_cmd.go index 10c737f..5a81b29 100644 --- a/cmd/wire/show_cmd.go +++ b/cmd/wire/show_cmd.go @@ -34,8 +34,8 @@ import ( ) type showCmd struct { - tags string - profile profileFlags + tags string + profile profileFlags } // Name returns the subcommand name. diff --git a/internal/loader/custom.go b/internal/loader/custom.go index dd2b9e0..52ebf4e 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -71,21 +71,21 @@ type customValidator struct { } type customTypedGraphLoader struct { - workspace string - ctx context.Context - env []string - fset *token.FileSet - meta map[string]*packageMeta - targets map[string]struct{} - parseFile ParseFileFunc - packages map[string]*packages.Package - typesPkgs map[string]*types.Package - importer types.Importer - loading map[string]bool - isLocalCache map[string]bool - localSemanticOK map[string]bool - artifactPrefetch map[string]artifactPrefetchEntry - stats typedLoadStats + workspace string + ctx context.Context + env []string + fset *token.FileSet + meta map[string]*packageMeta + targets map[string]struct{} + parseFile ParseFileFunc + packages map[string]*packages.Package + typesPkgs map[string]*types.Package + importer types.Importer + loading map[string]bool + isLocalCache map[string]bool + localSemanticOK map[string]bool + artifactPrefetch map[string]artifactPrefetchEntry + stats typedLoadStats } type artifactPrefetchEntry struct { @@ -263,21 +263,21 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz fset = 
token.NewFileSet() } l := &customTypedGraphLoader{ - workspace: detectModuleRoot(req.WD), - ctx: ctx, - env: append([]string(nil), req.Env...), - fset: fset, - meta: meta, - targets: map[string]struct{}{req.Package: {}}, - parseFile: req.ParseFile, - packages: make(map[string]*packages.Package, len(meta)), - typesPkgs: make(map[string]*types.Package, len(meta)), - importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), - loading: make(map[string]bool, len(meta)), - isLocalCache: make(map[string]bool, len(meta)), - localSemanticOK: make(map[string]bool, len(meta)), - artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), - stats: typedLoadStats{discovery: discoveryDuration}, + workspace: detectModuleRoot(req.WD), + ctx: ctx, + env: append([]string(nil), req.Env...), + fset: fset, + meta: meta, + targets: map[string]struct{}{req.Package: {}}, + parseFile: req.ParseFile, + packages: make(map[string]*packages.Package, len(meta)), + typesPkgs: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + isLocalCache: make(map[string]bool, len(meta)), + localSemanticOK: make(map[string]bool, len(meta)), + artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), + stats: typedLoadStats{discovery: discoveryDuration}, } prefetchStart := time.Now() l.prefetchArtifacts() @@ -345,21 +345,21 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo return nil, unsupportedError{reason: "no root packages from metadata"} } l := &customTypedGraphLoader{ - workspace: detectModuleRoot(req.WD), - ctx: ctx, - env: append([]string(nil), req.Env...), - fset: fset, - meta: meta, - targets: targets, - parseFile: req.ParseFile, - packages: make(map[string]*packages.Package, len(meta)), - typesPkgs: make(map[string]*types.Package, len(meta)), - importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), - loading: 
make(map[string]bool, len(meta)), - isLocalCache: make(map[string]bool, len(meta)), - localSemanticOK: make(map[string]bool, len(meta)), - artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), - stats: typedLoadStats{discovery: discoveryDuration}, + workspace: detectModuleRoot(req.WD), + ctx: ctx, + env: append([]string(nil), req.Env...), + fset: fset, + meta: meta, + targets: targets, + parseFile: req.ParseFile, + packages: make(map[string]*packages.Package, len(meta)), + typesPkgs: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + isLocalCache: make(map[string]bool, len(meta)), + localSemanticOK: make(map[string]bool, len(meta)), + artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), + stats: typedLoadStats{discovery: discoveryDuration}, } prefetchStart := time.Now() l.prefetchArtifacts() diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go index 4ec9a12..3b9fe46 100644 --- a/internal/loader/discovery_cache.go +++ b/internal/loader/discovery_cache.go @@ -14,15 +14,15 @@ import ( ) type discoveryCacheEntry struct { - Version int - WD string - Tags string - Patterns []string - NeedDeps bool - Workspace string - Meta map[string]*packageMeta - Global []discoveryFileMeta - LocalPkgs []discoveryLocalPackage + Version int + WD string + Tags string + Patterns []string + NeedDeps bool + Workspace string + Meta map[string]*packageMeta + Global []discoveryFileMeta + LocalPkgs []discoveryLocalPackage } type discoveryLocalPackage struct { diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index aff71e2..e2f5900 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -26,10 +26,10 @@ const ( ) type importBenchRow struct { - imports int - stockCold time.Duration - currentCold time.Duration - currentWarm time.Duration + imports int + 
stockCold time.Duration + currentCold time.Duration + currentWarm time.Duration } type importBenchScenarioRow struct { @@ -520,7 +520,7 @@ func appShapeGoMod(modulePath, wireModulePath, wireReplaceDir string, external b } return fmt.Sprintf(`module %s -go 1.26 +go 1.19 require ( %s v0.0.0%s @@ -1146,7 +1146,7 @@ func benchEnv(home, goCache string) []string { func importBenchGoMod(wireModulePath, wireReplaceDir string) string { return fmt.Sprintf(`module example.com/importbench -go 1.26 +go 1.19 require %s v0.0.0 @@ -1352,6 +1352,11 @@ func TestImportBenchUsesStockArchive(t *testing.T) { if err != nil { t.Fatal(err) } + check := exec.Command("git", "cat-file", "-e", stockWireCommit+"^{commit}") + check.Dir = repoRoot + if err := check.Run(); err != nil { + t.Skipf("stock archive commit %s not available in checkout", stockWireCommit) + } stockDir := extractStockWire(t, repoRoot, stockWireCommit) if _, err := os.Stat(filepath.Join(stockDir, "cmd", "wire", "main.go")); err != nil { t.Fatalf("stock archive missing cmd/wire: %v", err) From 4b312b510154d5b87dac53f846b13e14a1d5d0e9 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 21:20:15 -0500 Subject: [PATCH 19/82] fix: ci --- internal/loader/loader_test.go | 25 +++++++++++++++++-------- internal/loader/timing.go | 6 ++---- internal/wire/timing.go | 7 +++---- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 0c871e0..5f734b7 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -248,14 +248,23 @@ func TestValidateTouchedPackagesAutoReportsFallbackDetail(t *testing.T) { if err != nil { t.Fatalf("ValidateTouchedPackages(auto) error = %v", err) } - if got.Backend != ModeFallback { - t.Fatalf("backend = %q, want %q", got.Backend, ModeFallback) - } - if got.FallbackReason != FallbackReasonCustomUnsupported { - t.Fatalf("fallback reason = %q, want %q", got.FallbackReason, 
FallbackReasonCustomUnsupported) - } - if got.FallbackDetail != "metadata fingerprint mismatch" { - t.Fatalf("fallback detail = %q, want %q", got.FallbackDetail, "metadata fingerprint mismatch") + switch got.Backend { + case ModeCustom: + if got.FallbackReason != FallbackReasonNone { + t.Fatalf("fallback reason = %q, want empty for custom backend", got.FallbackReason) + } + if got.FallbackDetail != "" { + t.Fatalf("fallback detail = %q, want empty for custom backend", got.FallbackDetail) + } + case ModeFallback: + if got.FallbackReason != FallbackReasonCustomUnsupported { + t.Fatalf("fallback reason = %q, want %q", got.FallbackReason, FallbackReasonCustomUnsupported) + } + if got.FallbackDetail != "metadata fingerprint mismatch" { + t.Fatalf("fallback detail = %q, want %q", got.FallbackDetail, "metadata fingerprint mismatch") + } + default: + t.Fatalf("backend = %q, want %q or %q", got.Backend, ModeCustom, ModeFallback) } } diff --git a/internal/loader/timing.go b/internal/loader/timing.go index 4b902db..1ae9ccd 100644 --- a/internal/loader/timing.go +++ b/internal/loader/timing.go @@ -3,7 +3,6 @@ package loader import ( "context" "fmt" - "log" "time" ) @@ -49,8 +48,7 @@ func logInt(ctx context.Context, label string, v int) { } func debugf(ctx context.Context, format string, args ...interface{}) { - if timing(ctx) == nil { - return + if t := timing(ctx); t != nil { + t(fmt.Sprintf(format, args...), 0) } - log.Printf("timing: "+format, args...) 
} diff --git a/internal/wire/timing.go b/internal/wire/timing.go index d83754b..84c9022 100644 --- a/internal/wire/timing.go +++ b/internal/wire/timing.go @@ -16,7 +16,7 @@ package wire import ( "context" - "log" + "fmt" "time" ) @@ -52,8 +52,7 @@ func logTiming(ctx context.Context, label string, start time.Time) { } func debugf(ctx context.Context, format string, args ...interface{}) { - if timing(ctx) == nil { - return + if t := timing(ctx); t != nil { + t(fmt.Sprintf(format, args...), 0) } - log.Printf("timing: "+format, args...) } From 433ffc775b43d899f6cbae4bba2bc70a212b3c61 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 21:26:29 -0500 Subject: [PATCH 20/82] fix: windows tmpdir issue --- internal/loader/loader_test.go | 12 ++++++++---- internal/wire/import_bench_test.go | 9 ++++++++- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 5f734b7..27a129a 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -1118,12 +1118,16 @@ func normalizeErrorPos(pos string) string { if pos == "" || pos == "-" { return pos } - parts := strings.Split(pos, ":") - if len(parts) < 2 { + last := strings.LastIndex(pos, ":") + if last == -1 { return shortenComparablePath(normalizePathForCompare(pos)) } - path := shortenComparablePath(normalizePathForCompare(parts[0])) - return strings.Join(append([]string{path}, parts[1:]...), ":") + prev := strings.LastIndex(pos[:last], ":") + if prev == -1 { + return shortenComparablePath(normalizePathForCompare(pos)) + } + path := shortenComparablePath(normalizePathForCompare(pos[:prev])) + return path + pos[prev:] } func expandSummaryDiagnostics(msg string) []string { diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index e2f5900..4f5dd82 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -1137,12 +1137,19 @@ func benchEnv(home, goCache 
string) []string { env = append(env, "HOME="+home, "GOCACHE="+goCache, - "GOMODCACHE=/tmp/gomodcache", + "GOMODCACHE="+benchModCache(), "GOSUMDB=off", ) return env } +func benchModCache() string { + if path := os.Getenv("GOMODCACHE"); path != "" { + return path + } + return filepath.Join(os.TempDir(), "gomodcache") +} + func importBenchGoMod(wireModulePath, wireReplaceDir string) string { return fmt.Sprintf(`module example.com/importbench From d941179194e76e3891c9f746d8492e54847b9597 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 15 Mar 2026 21:32:44 -0500 Subject: [PATCH 21/82] fix: windows bench executable path --- internal/wire/import_bench_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index 4f5dd82..cd38190 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -10,6 +10,7 @@ import ( "os" "os/exec" "path/filepath" + "runtime" "strings" "testing" "time" @@ -314,6 +315,9 @@ func runAppKnownToggleTrials(t *testing.T, bin string, features, depPkgs int, ex func buildWireBinary(t *testing.T, dir, name string) string { t.Helper() + if runtime.GOOS == "windows" && filepath.Ext(name) != ".exe" { + name += ".exe" + } out := filepath.Join(t.TempDir(), name) cmd := exec.Command("go", "build", "-o", out, "./cmd/wire") cmd.Dir = dir From 88c86347dd3e4aa6fa656e558df310e9cf037ac6 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 14:49:15 -0500 Subject: [PATCH 22/82] fix(loader): strengthen artifact keys for replaced external modules --- internal/loader/artifact_cache.go | 23 +++++ internal/loader/loader_test.go | 160 ++++++++++++++++++++++++++++++ 2 files changed, 183 insertions(+) diff --git a/internal/loader/artifact_cache.go b/internal/loader/artifact_cache.go index 42293cb..e920d5a 100644 --- a/internal/loader/artifact_cache.go +++ b/internal/loader/artifact_cache.go @@ -72,6 +72,29 @@ func loaderArtifactKey(meta *packageMeta, 
isLocal bool) (string, error) { if !isLocal { sum.Write([]byte(meta.Export)) sum.Write([]byte{'\n'}) + if meta.Export != "" { + info, err := os.Stat(meta.Export) + if err != nil { + return "", err + } + sum.Write([]byte(strconv.FormatInt(info.Size(), 10))) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.ModTime().UnixNano(), 10))) + sum.Write([]byte{'\n'}) + } else { + for _, name := range metaFiles(meta) { + info, err := os.Stat(name) + if err != nil { + return "", err + } + sum.Write([]byte(name)) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.Size(), 10))) + sum.Write([]byte{'\n'}) + sum.Write([]byte(strconv.FormatInt(info.ModTime().UnixNano(), 10))) + sum.Write([]byte{'\n'}) + } + } if meta.Error != nil { sum.Write([]byte(meta.Error.Err)) sum.Write([]byte{'\n'}) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 27a129a..2010f56 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -702,6 +702,166 @@ func TestLoadTypedPackageGraphCustomExternalArtifactCacheReportsHits(t *testing. 
} } +func TestLoadTypedPackageGraphCustomArtifactCacheReplacedModuleSourceChange(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + artifactDir := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), "package dep\n\nfunc New() string { return \"ok\" }\n") + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\treturn dep.New()", + "}", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + load := func(mode Mode) (*LazyLoadResult, error) { + l := New() + return l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: appRoot, + Env: env, + Package: "example.com/app/app", + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: mode, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + } + + first, err := load(ModeCustom) + if err != nil { + t.Fatalf("first LoadTypedPackageGraph(custom) error = %v", err) + } + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * 
time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar l dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(l)", + "}", + "", + }, "\n")) + + custom, err := load(ModeCustom) + if err != nil { + t.Fatalf("second LoadTypedPackageGraph(custom) error = %v", err) + } + if len(custom.Packages) != 1 { + t.Fatalf("second custom packages len = %d, want 1", len(custom.Packages)) + } + if got := comparableErrors(custom.Packages[0].Errors); len(got) != 0 { + t.Fatalf("second custom load returned errors: %v", got) + } + + fallback, err := load(ModeFallback) + if err != nil { + t.Fatalf("second LoadTypedPackageGraph(fallback) error = %v", err) + } + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) +} + +func TestLoaderArtifactKeyExternalChangesWhenExportFileChanges(t *testing.T) { + exportPath := filepath.Join(t.TempDir(), "dep.a") + writeTestFile(t, exportPath, "first export payload") + + meta := &packageMeta{ + ImportPath: "example.com/dep", + Name: "dep", + Export: exportPath, + } + + first, err := loaderArtifactKey(meta, false) + if err != nil { + t.Fatalf("loaderArtifactKey(first) error = %v", err) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, exportPath, "second export payload with different contents") + + second, err := loaderArtifactKey(meta, false) + if err != nil { + t.Fatalf("loaderArtifactKey(second) error = %v", err) + } + + if first == second { + t.Fatalf("loaderArtifactKey did not change after export file update: %q", first) + } +} + +func 
TestLoaderArtifactKeyExternalWithoutExportChangesWhenSourceChanges(t *testing.T) { + sourcePath := filepath.Join(t.TempDir(), "dep.go") + writeTestFile(t, sourcePath, "package dep\n\nconst Name = \"first\"\n") + + meta := &packageMeta{ + ImportPath: "example.com/dep", + Name: "dep", + GoFiles: []string{sourcePath}, + CompiledGoFiles: []string{sourcePath}, + } + + first, err := loaderArtifactKey(meta, false) + if err != nil { + t.Fatalf("loaderArtifactKey(first) error = %v", err) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, sourcePath, "package dep\n\nconst Name = \"second\"\n") + + second, err := loaderArtifactKey(meta, false) + if err != nil { + t.Fatalf("loaderArtifactKey(second) error = %v", err) + } + + if first == second { + t.Fatalf("loaderArtifactKey did not change after external source update without export data: %q", first) + } +} + func TestLoadTypedPackageGraphCustomExternalArtifactCacheRealAppParity(t *testing.T) { root := os.Getenv("WIRE_REAL_APP_ROOT") if root == "" { From 09073ee95aff218e7a4feb21d0ba69936da135bf Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 16:12:02 -0500 Subject: [PATCH 23/82] test(loader): harden cache invalidation and discovery parity coverage --- internal/loader/custom.go | 18 +- internal/loader/discovery.go | 2 +- internal/loader/loader_test.go | 1917 ++++++++++++++++++++++++++++++++ 3 files changed, 1935 insertions(+), 2 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 52ebf4e..10f8c79 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -641,7 +641,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er return typed, nil } if len(dep.Errors) > 0 { - return nil, unsupportedError{reason: "lazy-load dependency has errors"} + return nil, dependencyImportError(dep) } return nil, unsupportedError{reason: "missing typed lazy-load dependency"} }), @@ -1100,6 +1100,22 @@ func toPackagesError(fset *token.FileSet, 
err error) packages.Error { } } +func dependencyImportError(pkg *packages.Package) error { + if pkg == nil { + return unsupportedError{reason: "lazy-load dependency has errors"} + } + if pkg.Name == "" { + return fmt.Errorf("invalid package name: %q", pkg.Name) + } + for _, err := range pkg.Errors { + if strings.TrimSpace(err.Msg) == "" { + continue + } + return fmt.Errorf("%s", err.Msg) + } + return unsupportedError{reason: "lazy-load dependency has errors"} +} + type importerFunc func(path string) (*types.Package, error) func (f importerFunc) Import(path string) (*types.Package, error) { return f(path) } diff --git a/internal/loader/discovery.go b/internal/loader/discovery.go index 22b34d3..0e7e69c 100644 --- a/internal/loader/discovery.go +++ b/internal/loader/discovery.go @@ -46,7 +46,7 @@ func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, return cached, nil } logDuration(ctx, "loader.discovery.cache_read.wall", time.Since(cacheReadStart)) - args := []string{"list", "-json", "-e", "-compiled"} + args := []string{"list", "-json", "-e", "-compiled", "-export"} if req.NeedDeps { args = append(args, "-deps") } diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 2010f56..1cb080d 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -15,13 +15,16 @@ package loader import ( + "archive/zip" "bytes" "context" + "fmt" "go/ast" "go/parser" "go/token" "go/types" "os" + "os/exec" "path/filepath" "sort" "strconv" @@ -805,6 +808,1522 @@ func TestLoadTypedPackageGraphCustomArtifactCacheReplacedModuleSourceChange(t *t compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) } +func TestDiscoveryCacheInvalidatesOnGoModResolutionChange(t *testing.T) { + root := t.TempDir() + depOneRoot := filepath.Join(root, "dep-one") + depTwoRoot := filepath.Join(root, "dep-two") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depOneRoot, 
"go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depOneRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"one\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(depTwoRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depTwoRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"strings\"", + "", + "func New() string { return strings.ToUpper(\"two\") }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depOneRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depTwoRoot, + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomCrossWorkspaceReplaceTargetIsolation(t *testing.T) { + cacheHome := t.TempDir() + 
artifactDir := t.TempDir() + repoOne := filepath.Join(t.TempDir(), "repo-one") + repoTwo := filepath.Join(t.TempDir(), "repo-two") + + depOneRoot := filepath.Join(repoOne, "depmod") + appOneRoot := filepath.Join(repoOne, "appmod") + depTwoRoot := filepath.Join(repoTwo, "depmod") + appTwoRoot := filepath.Join(repoTwo, "appmod") + + writeTestFile(t, filepath.Join(depOneRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depOneRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"one\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appOneRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depOneRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appOneRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(depTwoRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depTwoRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"strings\"", + "", + "func New() string { return strings.ToUpper(\"two\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appTwoRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depTwoRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appTwoRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+cacheHome, + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + warm := loadTypedPackageGraphForTest(t, appOneRoot, env, "example.com/app/app", 
ModeCustom) + if len(warm.Packages) != 1 || len(warm.Packages[0].Errors) != 0 { + t.Fatalf("repo one warm custom load returned errors: %+v", warm.Packages) + } + + custom := loadTypedPackageGraphForTest(t, appTwoRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appTwoRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomTransitiveShapeChangeWarmParity(t *testing.T) { + root := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "b", "b.go"), strings.Join([]string{ + "package b", + "", + "type T struct{}", + "", + "func New() *T { return &T{} }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "a", "a.go"), strings.Join([]string{ + "package a", + "", + "import \"example.com/app/b\"", + "", + "func New() *b.T { return b.New() }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/a\"", + "", + "func Init() any { return a.New() }", + "", + }, "\n")) + + first := loadTypedPackageGraphForTest(t, root, os.Environ(), "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "b", "b.go"), strings.Join([]string{ + "package b", + "", + "type T struct{}", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) *T { return &T{} }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "a", "a.go"), strings.Join([]string{ + "package a", + "", + "import 
\"example.com/app/b\"", + "", + "func New() *b.T {", + "\tvar logger b.Logger = b.NoopLogger{}", + "\treturn b.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, os.Environ(), "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, os.Environ(), "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomReplacePathSwitchInvalidatesCaches(t *testing.T) { + root := t.TempDir() + depOneRoot := filepath.Join(root, "dep-one") + depTwoRoot := filepath.Join(root, "dep-two") + appRoot := filepath.Join(root, "appmod") + artifactDir := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depOneRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depOneRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"one\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(depTwoRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depTwoRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"strings\"", + "", + "func New() string { return strings.TrimSpace(\" two \") }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depOneRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || 
len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depTwoRoot, + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomDiscoveryCacheReplacedSiblingOutsideWorkspace(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + rootLoad := loadRootGraphForTest(t, appRoot, env, []string{"./app"}, ModeCustom) + if rootLoad.Discovery == nil { + t.Fatal("expected discovery snapshot from custom root load") + } + + first := loadTypedPackageGraphWithDiscoveryForTest(t, appRoot, env, "example.com/app/app", 
ModeCustom, rootLoad.Discovery) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphWithDiscoveryForTest(t, appRoot, env, "example.com/app/app", ModeCustom, rootLoad.Discovery) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestDiscoveryCacheInvalidatesOnGeneratedFileSetChange(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"base\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load 
returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "zz_generated.go"), strings.Join([]string{ + "package dep", + "", + "func Generated() string { return \"generated\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.New() + dep.Generated() }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomBodyOnlyEditWarmParity(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string {", + "\treturn fmt.Sprint(\"before\")", + "}", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string {", + "\treturn fmt.Sprint(\"after\")", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, 
"example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/app/dep", true) +} + +func TestLoadTypedPackageGraphCustomReplaceNestedModuleParity(t *testing.T) { + root := t.TempDir() + appRoot := filepath.Join(root, "appmod") + depRoot := filepath.Join(appRoot, "third_party", "depmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => ./third_party/depmod", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func 
Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomReplaceChainParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + midRoot := filepath.Join(root, "midmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(midRoot, "go.mod"), "module example.com/mid\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(midRoot, "mid.go"), strings.Join([]string{ + "package mid", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require (", + "\texample.com/dep v0.0.0", + "\texample.com/mid v0.0.0", + ")", + "", + "replace example.com/dep => " + depRoot, + "replace example.com/mid => " + midRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/mid\"", + "", + "func Use() string { return mid.Use() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || 
len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(midRoot, "mid.go"), strings.Join([]string{ + "package mid", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/mid", false) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomGoWorkWorkspaceParity(t *testing.T) { + root := t.TempDir() + appRoot := filepath.Join(root, "appmod") + depRoot := filepath.Join(root, "depmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.work"), strings.Join([]string{ + "go 1.19", + "", + "use (", + "\t./appmod", + "\t./depmod", + ")", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, 
"app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"strings\"", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return strings.TrimSpace(\" ok \") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomCrossWorkspaceModuleIsolation(t *testing.T) { + cacheHome := t.TempDir() + repoOne := filepath.Join(t.TempDir(), "repo-one") + repoTwo := filepath.Join(t.TempDir(), "repo-two") + + writeTestFile(t, filepath.Join(repoOne, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(repoOne, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "func Message() string { return \"one\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(repoOne, "app", "wire.go"), 
strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.Message() }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(repoTwo, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(repoTwo, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"strings\"", + "", + "func Message() string { return strings.ToUpper(\"two\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(repoTwo, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+cacheHome) + warm := loadTypedPackageGraphForTest(t, repoOne, env, "example.com/app/app", ModeCustom) + if len(warm.Packages) != 1 || len(warm.Packages[0].Errors) != 0 { + t.Fatalf("repo one warm custom load returned errors: %+v", warm.Packages) + } + + custom := loadTypedPackageGraphForTest(t, repoTwo, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, repoTwo, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/app/dep", true) +} + +func TestDiscoveryCacheInvalidatesOnLocalImportChangeEndToEnd(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "func Base() string { return \"base\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "extra", "extra.go"), strings.Join([]string{ + "package extra", + "", + "func Value() string { return \"extra\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import 
\"example.com/app/dep\"", + "", + "func Init() string { return dep.Base() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"example.com/app/extra\"", + "", + "func Base() string { return extra.Value() }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomLocalShapeChangeWarmParity(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type T struct{}", + "", + "func New() *T { return &T{} }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() *dep.T { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Config struct{}", + "", + "type T struct{}", + "", + "func New(Config) 
*T { return &T{} }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() *dep.T { return dep.New(dep.Config{}) }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomTransitiveBodyOnlyWarmParity(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "b", "b.go"), strings.Join([]string{ + "package b", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"before\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "a", "a.go"), strings.Join([]string{ + "package a", + "", + "import \"example.com/app/b\"", + "", + "func Message() string { return b.Message() }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/a\"", + "", + "func Init() string { return a.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "b", "b.go"), strings.Join([]string{ + "package b", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"after\") }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := 
loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/app/a", true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/app/b", true) +} + +func TestLoadTypedPackageGraphCustomKnownShapeToggleWarmParity(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Config struct { Name string }", + "", + "func New(Config) string { return \"config\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.New(dep.Config{Name: \"a\"}) }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"logger\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback 
:= loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomNewShapeWarmParity(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), "module example.com/app\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "dep", "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Config struct{}", + "", + "func New() string { return \"ok\" }", + "", + "func NewWithConfig(Config) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/dep\"", + "", + "func Init() string { return dep.NewWithConfig(dep.Config{}) }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + comparePackageGraphs(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomReplaceTargetBodyOnlyWarmParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + 
homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"before\") }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"after\") }", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomReplaceTargetShapeChangeWarmParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, 
filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomFixtureAppWarmMutationParity(t *testing.T) { + root := t.TempDir() + appRoot := filepath.Join(root, "appmod") + depRoot := 
filepath.Join(root, "depmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"dep\") }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "base", "base.go"), strings.Join([]string{ + "package base", + "", + "import \"fmt\"", + "", + "func Prefix() string { return fmt.Sprint(\"base:\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "gen", "zz_generated.go"), strings.Join([]string{ + "package gen", + "", + "func Value() string { return \"generated\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "feature", "feature.go"), strings.Join([]string{ + "package feature", + "", + "import (", + "\t\"example.com/app/base\"", + "\t\"example.com/app/gen\"", + "\t\"example.com/dep\"", + ")", + "", + "func Message() string {", + "\treturn base.Prefix() + dep.Message() + gen.Value()", + "}", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/app/feature\"", + "", + "func Init() string { return feature.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + coldCustom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(coldCustom.Packages) != 1 || len(coldCustom.Packages[0].Errors) != 0 { + t.Fatalf("cold custom load returned errors: %+v", coldCustom.Packages) + } + coldFallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, 
coldCustom.Packages, coldFallback.Packages, true) + comparePackageByPath(t, coldCustom.Packages, coldFallback.Packages, "example.com/app/feature", true) + comparePackageByPath(t, coldCustom.Packages, coldFallback.Packages, "example.com/app/gen", true) + comparePackageByPath(t, coldCustom.Packages, coldFallback.Packages, "example.com/dep", false) + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func Message(Logger) string { return fmt.Sprint(\"dep2\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "gen", "zz_generated.go"), strings.Join([]string{ + "package gen", + "", + "func Value() string { return \"generated2\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "feature", "feature.go"), strings.Join([]string{ + "package feature", + "", + "import (", + "\t\"example.com/app/base\"", + "\t\"example.com/app/gen\"", + "\t\"example.com/dep\"", + ")", + "", + "func Message() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn base.Prefix() + dep.Message(logger) + gen.Value()", + "}", + "", + }, "\n")) + + warmCustom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + warmFallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, warmCustom.Packages, warmFallback.Packages, true) + comparePackageByPath(t, warmCustom.Packages, warmFallback.Packages, "example.com/app/feature", true) + comparePackageByPath(t, warmCustom.Packages, warmFallback.Packages, "example.com/app/gen", true) + comparePackageByPath(t, warmCustom.Packages, warmFallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomSequentialMutationsParity(t *testing.T) { + root := t.TempDir() + 
depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"dep\") }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "helper", "helper.go"), strings.Join([]string{ + "package helper", + "", + "func Prefix() string { return \"prefix:\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Init() string { return dep.Message() }", + "", + }, "\n")) + + env := append(os.Environ(), "HOME="+homeDir) + + assertParity := func() { + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) + } + + initial := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(initial.Packages) != 1 || len(initial.Packages[0].Errors) != 0 { + t.Fatalf("initial custom load returned errors: %+v", initial.Packages) + } + assertParity() + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"dep-body\") }", + "", + }, "\n")) + assertParity() + + 
time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(appRoot, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import (", + "\t\"example.com/app/helper\"", + "\t\"example.com/dep\"", + ")", + "", + "func Init() string { return helper.Prefix() + dep.Message() }", + "", + }, "\n")) + assertParity() + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func Message(Logger) string { return fmt.Sprint(\"dep-shape\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import (", + "\t\"example.com/app/helper\"", + "\t\"example.com/dep\"", + ")", + "", + "func Init() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn helper.Prefix() + dep.Message(logger)", + "}", + "", + }, "\n")) + assertParity() + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "import \"fmt\"", + "", + "func Message() string { return fmt.Sprint(\"dep\") }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Init() string { return dep.Message() }", + "", + }, "\n")) + assertParity() +} + +func TestDiscoveryCacheInvalidatesOnGoSumResolutionChange(t *testing.T) { + root := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require github.com/google/go-cmp v0.6.0", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "go.sum"), strings.Join([]string{ + "github.com/google/go-cmp v0.6.0 
h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=", + "github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=", + "", + }, "\n")) + writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"github.com/google/go-cmp/cmp\"", + "", + "func Init() string { return cmp.Diff(\"a\", \"b\") }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + "GOPROXY=off", + "GONOSUMDB=*", + "GOCACHE=/tmp/gocache", + "GOMODCACHE=/tmp/gomodcache", + ) + + first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 { + t.Fatalf("first custom packages len = %d, want 1", len(first.Packages)) + } + if got := comparableErrors(first.Packages[0].Errors); len(got) != 0 { + t.Fatalf("first custom load returned errors: %v", got) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "go.sum"), "") + + custom := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) +} + +func TestLoadTypedPackageGraphCustomExternalVersionChangeBustsCache(t *testing.T) { + root := t.TempDir() + proxyDir := t.TempDir() + artifactDir := t.TempDir() + homeDir := t.TempDir() + + writeModuleProxyVersion(t, proxyDir, "example.com/extdep", "v1.0.0", map[string]string{ + "pkg/pkg.go": "package pkg\n\nfunc Version() string { return \"v1.0.0\" }\n", + }) + writeModuleProxyVersion(t, proxyDir, "example.com/extdep", "v1.1.0", map[string]string{ + "pkg/pkg.go": "package pkg\n\nimport \"strings\"\n\nfunc Version() string { return strings.TrimSpace(\"v1.1.0\") }\n", + }) + + writeTestFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/extdep v1.0.0", + "", + }, "\n")) + 
writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/extdep/pkg\"", + "", + "func Init() string { return pkg.Version() }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + "GOPROXY=file://"+proxyDir, + "GOSUMDB=off", + "GOCACHE=/tmp/gocache", + "GOMODCACHE=/tmp/gomodcache", + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + runGoModTidyForTest(t, root, env) + + first := loadPackagesForTest(t, root, env, []string{"./app"}, ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + firstDep := collectGraph(first.Packages)["example.com/extdep/pkg"] + if firstDep == nil { + t.Fatal("expected dependency package for example.com/extdep/pkg") + } + if !containsPathSubstring(firstDep.CompiledGoFiles, "example.com/extdep@v1.0.0") { + t.Fatalf("first dependency files = %v, want version v1.0.0", firstDep.CompiledGoFiles) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/extdep v1.1.0", + "", + }, "\n")) + runGoModTidyForTest(t, root, env) + + custom := loadPackagesForTest(t, root, env, []string{"./app"}, ModeCustom) + fallback := loadPackagesForTest(t, root, env, []string{"./app"}, ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/extdep/pkg", false) + + secondDep := collectGraph(custom.Packages)["example.com/extdep/pkg"] + if secondDep == nil { + t.Fatal("expected dependency package for example.com/extdep/pkg after version change") + } + if !containsPathSubstring(secondDep.CompiledGoFiles, "example.com/extdep@v1.1.0") { + t.Fatalf("second dependency files = %v, want version v1.1.0", secondDep.CompiledGoFiles) + } +} + 
func TestLoaderArtifactKeyExternalChangesWhenExportFileChanges(t *testing.T) { exportPath := filepath.Join(t.TempDir(), "dep.a") writeTestFile(t, exportPath, "first export payload") @@ -862,6 +2381,225 @@ func TestLoaderArtifactKeyExternalWithoutExportChangesWhenSourceChanges(t *testi } } +func TestRunGoListIncludesExportDataForReplacedModule(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), "package dep\n\nfunc New() string { return \"ok\" }\n") + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), "package app\n\nimport \"example.com/dep\"\n\nfunc Use() string { return dep.New() }\n") + + meta, err := runGoList(context.Background(), goListRequest{ + WD: appRoot, + Env: append(os.Environ(), "GOCACHE=/tmp/gocache", "GOMODCACHE=/tmp/gomodcache"), + Patterns: []string{"example.com/app/app"}, + NeedDeps: true, + }) + if err != nil { + t.Fatalf("runGoList() error = %v", err) + } + depMeta := meta["example.com/dep"] + if depMeta == nil { + t.Fatal("expected metadata for example.com/dep") + } + if depMeta.Export == "" { + t.Fatalf("expected export data path for replaced module metadata: %+v", depMeta) + } +} + +func TestLoadTypedPackageGraphCustomReplaceTargetWithExportDataWarmParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + artifactDir := t.TempDir() + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), 
strings.Join([]string{ + "package dep", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + "GOCACHE=/tmp/gocache", + "GOMODCACHE=/tmp/gomodcache", + loaderArtifactEnv+"=1", + loaderArtifactDirEnv+"="+artifactDir, + ) + + meta, err := runGoList(context.Background(), goListRequest{ + WD: appRoot, + Env: env, + Patterns: []string{"example.com/app/app"}, + NeedDeps: true, + }) + if err != nil { + t.Fatalf("runGoList() error = %v", err) + } + depMeta := meta["example.com/dep"] + if depMeta == nil || depMeta.Export == "" { + t.Fatalf("expected export-backed metadata for example.com/dep: %+v", depMeta) + } + + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 || len(first.Packages[0].Errors) != 0 { + t.Fatalf("first custom load returned errors: %+v", first.Packages) + } + + time.Sleep(10 * time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "type Logger interface { Log(string) }", + "", + "type NoopLogger struct{}", + "", + "func (NoopLogger) Log(string) {}", + "", + "func New(Logger) string { return \"ok\" }", + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string {", + "\tvar logger dep.Logger = dep.NoopLogger{}", + "\treturn dep.New(logger)", + "}", + "", + }, "\n")) + + custom := loadTypedPackageGraphForTest(t, 
appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) + comparePackageByPath(t, custom.Packages, fallback.Packages, "example.com/dep", false) +} + +func TestLoadTypedPackageGraphCustomReplaceTargetWithoutExportDataWarmParity(t *testing.T) { + root := t.TempDir() + depRoot := filepath.Join(root, "depmod") + appRoot := filepath.Join(root, "appmod") + homeDir := t.TempDir() + + writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "//go:build never", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + writeTestFile(t, filepath.Join(appRoot, "go.mod"), strings.Join([]string{ + "module example.com/app", + "", + "go 1.19", + "", + "require example.com/dep v0.0.0", + "", + "replace example.com/dep => " + depRoot, + "", + }, "\n")) + writeTestFile(t, filepath.Join(appRoot, "app", "app.go"), strings.Join([]string{ + "package app", + "", + "import \"example.com/dep\"", + "", + "func Use() string { return dep.New() }", + "", + }, "\n")) + + env := append(os.Environ(), + "HOME="+homeDir, + "GOCACHE=/tmp/gocache", + "GOMODCACHE=/tmp/gomodcache", + ) + + meta, err := runGoList(context.Background(), goListRequest{ + WD: appRoot, + Env: env, + Patterns: []string{"example.com/app/app"}, + NeedDeps: true, + }) + if err != nil { + t.Fatalf("runGoList(first) error = %v", err) + } + depMeta := meta["example.com/dep"] + if depMeta == nil || depMeta.Export != "" { + t.Fatalf("expected no export data for incomplete replaced module: %+v", depMeta) + } + + first := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + if len(first.Packages) != 1 { + t.Fatalf("first custom packages len = %d, want 1", len(first.Packages)) + } + + time.Sleep(10 * 
time.Millisecond) + writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ + "package dep", + "", + "var _ missing", + "", + "func New() string { return \"ok\" }", + "", + }, "\n")) + + meta, err = runGoList(context.Background(), goListRequest{ + WD: appRoot, + Env: env, + Patterns: []string{"example.com/app/app"}, + NeedDeps: true, + }) + if err != nil { + t.Fatalf("runGoList(second) error = %v", err) + } + depMeta = meta["example.com/dep"] + if depMeta == nil || depMeta.Export != "" { + t.Fatalf("expected no export data for second incomplete replaced module state: %+v", depMeta) + } + + custom := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeCustom) + fallback := loadTypedPackageGraphForTest(t, appRoot, env, "example.com/app/app", ModeFallback) + compareRootPackagesOnly(t, custom.Packages, fallback.Packages, true) +} + func TestLoadTypedPackageGraphCustomExternalArtifactCacheRealAppParity(t *testing.T) { root := os.Getenv("WIRE_REAL_APP_ROOT") if root == "" { @@ -1142,6 +2880,44 @@ func compareRootPackagesOnly(t *testing.T, got []*packages.Package, want []*pack } } +func comparePackageByPath(t *testing.T, got []*packages.Package, want []*packages.Package, pkgPath string, requireTyped bool) { + t.Helper() + gotPkg := collectGraph(got)[pkgPath] + if gotPkg == nil { + t.Fatalf("missing package %q in custom graph", pkgPath) + } + wantPkg := collectGraph(want)[pkgPath] + if wantPkg == nil { + t.Fatalf("missing package %q in fallback graph", pkgPath) + } + if gotPkg.Name != wantPkg.Name { + t.Fatalf("package %q name = %q, want %q", pkgPath, gotPkg.Name, wantPkg.Name) + } + if !equalStrings(gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) { + t.Fatalf("package %q compiled files = %v, want %v", pkgPath, gotPkg.CompiledGoFiles, wantPkg.CompiledGoFiles) + } + if !equalImportPaths(gotPkg.Imports, wantPkg.Imports) { + t.Fatalf("package %q imports = %v, want %v", pkgPath, sortedImportPaths(gotPkg.Imports), 
sortedImportPaths(wantPkg.Imports)) + } + gotErrs := comparableErrors(gotPkg.Errors) + wantErrs := comparableErrors(wantPkg.Errors) + if len(gotErrs) != len(wantErrs) { + t.Fatalf("package %q comparable errors len = %d, want %d; got=%v want=%v", pkgPath, len(gotErrs), len(wantErrs), gotErrs, wantErrs) + } + for i := range gotErrs { + if gotErrs[i] != wantErrs[i] { + t.Fatalf("package %q comparable error[%d] = %q, want %q", pkgPath, i, gotErrs[i], wantErrs[i]) + } + } + if requireTyped { + gotTyped := gotPkg.Types != nil && gotPkg.TypesInfo != nil && len(gotPkg.Syntax) > 0 + wantTyped := wantPkg.Types != nil && wantPkg.TypesInfo != nil && len(wantPkg.Syntax) > 0 + if gotTyped != wantTyped { + t.Fatalf("package %q typed state = %v, want %v", pkgPath, gotTyped, wantTyped) + } + } +} + func collectGraph(roots []*packages.Package) map[string]*packages.Package { out := make(map[string]*packages.Package) stack := append([]*packages.Package(nil), roots...) @@ -1159,6 +2935,68 @@ func collectGraph(roots []*packages.Package) map[string]*packages.Package { return out } +func loadTypedPackageGraphForTest(t *testing.T, wd string, env []string, pkg string, mode Mode) *LazyLoadResult { + return loadTypedPackageGraphWithDiscoveryForTest(t, wd, env, pkg, mode, nil) +} + +func loadPackagesForTest(t *testing.T, wd string, env []string, patterns []string, mode Mode) *PackageLoadResult { + t.Helper() + l := New() + got, err := l.LoadPackages(context.Background(), PackageLoadRequest{ + WD: wd, + Env: env, + Patterns: patterns, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: mode, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + }) + if err 
!= nil { + t.Fatalf("LoadPackages(%q, %q) error = %v", wd, mode, err) + } + return got +} + +func loadTypedPackageGraphWithDiscoveryForTest(t *testing.T, wd string, env []string, pkg string, mode Mode, discovery *DiscoverySnapshot) *LazyLoadResult { + t.Helper() + l := New() + got, err := l.LoadTypedPackageGraph(context.Background(), LazyLoadRequest{ + WD: wd, + Env: env, + Package: pkg, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedDeps | packages.NeedTypes | packages.NeedTypesInfo | packages.NeedSyntax | packages.NeedExportFile, + LoaderMode: mode, + Fset: token.NewFileSet(), + ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { + return parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) + }, + Discovery: discovery, + }) + if err != nil { + t.Fatalf("LoadTypedPackageGraph(%q, %q) error = %v", wd, mode, err) + } + return got +} + +func loadRootGraphForTest(t *testing.T, wd string, env []string, patterns []string, mode Mode) *RootLoadResult { + t.Helper() + l := New() + got, err := l.LoadRootGraph(context.Background(), RootLoadRequest{ + WD: wd, + Env: env, + Patterns: patterns, + NeedDeps: true, + Mode: mode, + Fset: token.NewFileSet(), + }) + if err != nil { + t.Fatalf("LoadRootGraph(%q, %q) error = %v", wd, mode, err) + } + return got +} + func compiledFileCount(pkgs map[string]*packages.Package) int { total := 0 for _, pkg := range pkgs { @@ -1202,6 +3040,85 @@ func sortedImportPaths(m map[string]*packages.Package) []string { return out } +func containsPathSubstring(paths []string, needle string) bool { + for _, path := range paths { + if strings.Contains(normalizePathForCompare(path), needle) { + return true + } + } + return false +} + +func runGoModTidyForTest(t *testing.T, wd string, env []string) { + t.Helper() + cmd := exec.Command("go", "mod", "tidy") + cmd.Dir = wd + cmd.Env = env + out, err := cmd.CombinedOutput() 
+ if err != nil { + t.Fatalf("go mod tidy in %q error = %v: %s", wd, err, out) + } +} + +func writeModuleProxyVersion(t *testing.T, proxyDir string, modulePath string, version string, files map[string]string) { + t.Helper() + base := filepath.Join(proxyDir, filepath.FromSlash(modulePath), "@v") + if err := os.MkdirAll(base, 0o755); err != nil { + t.Fatalf("mkdir proxy dir: %v", err) + } + listPath := filepath.Join(base, "list") + appendLineIfMissing(t, listPath, version) + + modFile := "module " + modulePath + "\n\ngo 1.19\n" + writeTestFile(t, filepath.Join(base, version+".mod"), modFile) + writeTestFile(t, filepath.Join(base, version+".info"), fmt.Sprintf("{\"Version\":%q,\"Time\":\"2024-01-01T00:00:00Z\"}\n", version)) + + zipPath := filepath.Join(base, version+".zip") + zipFile, err := os.Create(zipPath) + if err != nil { + t.Fatalf("create proxy zip: %v", err) + } + defer zipFile.Close() + + zw := zip.NewWriter(zipFile) + moduleRoot := modulePath + "@" + version + writeZipFile := func(name string, contents string) { + w, err := zw.Create(moduleRoot + "/" + filepath.ToSlash(name)) + if err != nil { + t.Fatalf("create zip entry %q: %v", name, err) + } + if _, err := w.Write([]byte(contents)); err != nil { + t.Fatalf("write zip entry %q: %v", name, err) + } + } + writeZipFile("go.mod", modFile) + for name, contents := range files { + writeZipFile(name, contents) + } + if err := zw.Close(); err != nil { + t.Fatalf("close proxy zip: %v", err) + } +} + +func appendLineIfMissing(t *testing.T, path string, line string) { + t.Helper() + existing, err := os.ReadFile(path) + if err != nil && !os.IsNotExist(err) { + t.Fatalf("read %q: %v", path, err) + } + for _, existingLine := range strings.Split(strings.TrimSpace(string(existing)), "\n") { + if existingLine == line { + return + } + } + content := string(existing) + if strings.TrimSpace(content) != "" && !strings.HasSuffix(content, "\n") { + content += "\n" + } + content += line + "\n" + writeTestFile(t, path, content) 
+} + type importerFuncForTest func(string) (*types.Package, error) func (f importerFuncForTest) Import(path string) (*types.Package, error) { From fea5e0ab08738a121576c549eb2e7916f5d35917 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 16:19:24 -0500 Subject: [PATCH 24/82] fix(loader): treat replaced workspace deps as local and harden runtests --- internal/loader/custom.go | 47 +++++++++++++++++++++++++++++- internal/loader/discovery_cache.go | 22 +++++++++++--- internal/loader/loader_test.go | 21 ++++++------- internal/runtests.sh | 8 ++++- 4 files changed, 82 insertions(+), 16 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 10f8c79..18c0f7d 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -54,9 +54,19 @@ type packageMeta struct { CompiledGoFiles []string Imports []string ImportMap map[string]string + Module *goListModule Error *goListError } +type goListModule struct { + Path string + Version string + Main bool + Dir string + GoMod string + Replace *goListModule +} + type goListError struct { Err string } @@ -882,7 +892,7 @@ func (l *customTypedGraphLoader) isLocalPackage(importPath string, meta *package if local, ok := l.isLocalCache[importPath]; ok { return local } - local := isWorkspacePackage(l.workspace, meta.Dir) + local := isLocalSourcePackage(l.workspace, meta) l.isLocalCache[importPath] = local return local } @@ -1262,6 +1272,41 @@ func isWorkspacePackage(workspaceRoot, dir string) bool { return rel != ".." 
&& !strings.HasPrefix(rel, ".."+string(filepath.Separator)) } +func isLocalSourcePackage(workspaceRoot string, meta *packageMeta) bool { + if meta == nil { + return false + } + if isWorkspacePackage(workspaceRoot, meta.Dir) { + return true + } + mod := localSourceModule(meta.Module) + if mod == nil { + return false + } + if mod.Main { + return true + } + return canonicalLoaderPath(mod.Dir) == canonicalLoaderPath(meta.Dir) || isWorkspacePackage(canonicalLoaderPath(mod.Dir), meta.Dir) +} + +func localSourceModule(mod *goListModule) *goListModule { + if mod == nil { + return nil + } + if mod.Replace != nil { + if local := localSourceModule(mod.Replace); local != nil { + return local + } + } + if mod.Main && mod.Dir != "" { + return mod + } + if mod.Replace != nil && mod.Replace.Dir != "" { + return mod.Replace + } + return nil +} + func detectModuleRoot(start string) string { start = canonicalLoaderPath(start) for dir := start; dir != "" && dir != "." && dir != string(filepath.Separator); dir = filepath.Dir(dir) { diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go index 3b9fe46..e3db86b 100644 --- a/internal/loader/discovery_cache.go +++ b/internal/loader/discovery_cache.go @@ -71,7 +71,7 @@ func writeDiscoveryCache(req goListRequest, meta map[string]*packageMeta) { func buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) (*discoveryCacheEntry, error) { workspace := detectModuleRoot(req.WD) entry := &discoveryCacheEntry{ - Version: 2, + Version: 3, WD: canonicalLoaderPath(req.WD), Tags: req.Tags, Patterns: append([]string(nil), req.Patterns...), @@ -92,7 +92,7 @@ func buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) ( } locals := make([]discoveryLocalPackage, 0) for _, pkg := range meta { - if pkg == nil || !isWorkspacePackage(workspace, pkg.Dir) { + if pkg == nil || !isLocalSourcePackage(workspace, pkg) { continue } lp := discoveryLocalPackage{ @@ -116,7 +116,7 @@ func 
buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) ( } func validateDiscoveryCacheEntry(entry *discoveryCacheEntry) bool { - if entry == nil || entry.Version != 2 { + if entry == nil || entry.Version != 3 { return false } for _, fm := range entry.Global { @@ -150,7 +150,7 @@ func discoveryCachePath(req goListRequest) (string, error) { NeedDeps bool Go string }{ - Version: 2, + Version: 3, WD: canonicalLoaderPath(req.WD), Tags: req.Tags, Patterns: append([]string(nil), req.Patterns...), @@ -321,6 +321,9 @@ func clonePackageMetaMap(in map[string]*packageMeta) map[string]*packageMeta { cp.ImportMap[mk] = mv } } + if v.Module != nil { + cp.Module = cloneGoListModule(v.Module) + } if v.Error != nil { errCopy := *v.Error cp.Error = &errCopy @@ -329,3 +332,14 @@ func clonePackageMetaMap(in map[string]*packageMeta) map[string]*packageMeta { } return out } + +func cloneGoListModule(in *goListModule) *goListModule { + if in == nil { + return nil + } + cp := *in + if in.Replace != nil { + cp.Replace = cloneGoListModule(in.Replace) + } + return &cp +} diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 1cb080d..065b4fb 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -2198,37 +2198,38 @@ func TestLoadTypedPackageGraphCustomSequentialMutationsParity(t *testing.T) { func TestDiscoveryCacheInvalidatesOnGoSumResolutionChange(t *testing.T) { root := t.TempDir() + proxyDir := t.TempDir() homeDir := t.TempDir() + writeModuleProxyVersion(t, proxyDir, "example.com/extdep", "v1.0.0", map[string]string{ + "pkg/pkg.go": "package pkg\n\nfunc Version() string { return \"v1.0.0\" }\n", + }) + writeTestFile(t, filepath.Join(root, "go.mod"), strings.Join([]string{ "module example.com/app", "", "go 1.19", "", - "require github.com/google/go-cmp v0.6.0", - "", - }, "\n")) - writeTestFile(t, filepath.Join(root, "go.sum"), strings.Join([]string{ - "github.com/google/go-cmp v0.6.0 
h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=", - "github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=", + "require example.com/extdep v1.0.0", "", }, "\n")) writeTestFile(t, filepath.Join(root, "app", "wire.go"), strings.Join([]string{ "package app", "", - "import \"github.com/google/go-cmp/cmp\"", + "import \"example.com/extdep/pkg\"", "", - "func Init() string { return cmp.Diff(\"a\", \"b\") }", + "func Init() string { return pkg.Version() }", "", }, "\n")) env := append(os.Environ(), "HOME="+homeDir, - "GOPROXY=off", - "GONOSUMDB=*", + "GOPROXY=file://"+proxyDir, + "GOSUMDB=off", "GOCACHE=/tmp/gocache", "GOMODCACHE=/tmp/gomodcache", ) + runGoModTidyForTest(t, root, env) first := loadTypedPackageGraphForTest(t, root, env, "example.com/app/app", ModeCustom) if len(first.Packages) != 1 { diff --git a/internal/runtests.sh b/internal/runtests.sh index 28877c1..7d2ddcb 100755 --- a/internal/runtests.sh +++ b/internal/runtests.sh @@ -16,6 +16,9 @@ # https://coderwall.com/p/fkfaqq/safer-bash-scripts-with-set-euxo-pipefail set -euo pipefail +export GOCACHE="${GOCACHE:-/tmp/gocache}" +export GOMODCACHE="${GOMODCACHE:-/tmp/gomodcache}" + if [[ $# -gt 0 ]]; then echo "usage: runtests.sh" 1>&2 exit 64 @@ -34,7 +37,10 @@ fi echo echo "Ensuring .go files are formatted with gofmt -s..." -mapfile -t go_files < <(find . -name '*.go' -type f | grep -v testdata) +go_files=() +while IFS= read -r file; do + go_files+=("$file") +done < <(find . 
-name '*.go' -type f | grep -v testdata) DIFF="$(gofmt -s -d "${go_files[@]}")" if [ -n "$DIFF" ]; then echo "FAIL: please run gofmt -s and commit the result" From 568d2e08e3de5035cc6257405399e210f9394a51 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 16:35:40 -0500 Subject: [PATCH 25/82] fix(loader): make cache-hardening tests and runtests portable --- internal/loader/loader_test.go | 51 ++++++++++++++++++++++++++++------ internal/runtests.sh | 9 ++++-- 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 065b4fb..4e94afc 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -2200,6 +2200,8 @@ func TestDiscoveryCacheInvalidatesOnGoSumResolutionChange(t *testing.T) { root := t.TempDir() proxyDir := t.TempDir() homeDir := t.TempDir() + goCacheDir := tempCacheDirForTest(t, "wire-gocache-") + goModCacheDir := tempCacheDirForTest(t, "wire-gomodcache-") writeModuleProxyVersion(t, proxyDir, "example.com/extdep", "v1.0.0", map[string]string{ "pkg/pkg.go": "package pkg\n\nfunc Version() string { return \"v1.0.0\" }\n", @@ -2226,8 +2228,8 @@ func TestDiscoveryCacheInvalidatesOnGoSumResolutionChange(t *testing.T) { "HOME="+homeDir, "GOPROXY=file://"+proxyDir, "GOSUMDB=off", - "GOCACHE=/tmp/gocache", - "GOMODCACHE=/tmp/gomodcache", + "GOCACHE="+goCacheDir, + "GOMODCACHE="+goModCacheDir, ) runGoModTidyForTest(t, root, env) @@ -2252,6 +2254,8 @@ func TestLoadTypedPackageGraphCustomExternalVersionChangeBustsCache(t *testing.T proxyDir := t.TempDir() artifactDir := t.TempDir() homeDir := t.TempDir() + goCacheDir := tempCacheDirForTest(t, "wire-gocache-") + goModCacheDir := tempCacheDirForTest(t, "wire-gomodcache-") writeModuleProxyVersion(t, proxyDir, "example.com/extdep", "v1.0.0", map[string]string{ "pkg/pkg.go": "package pkg\n\nfunc Version() string { return \"v1.0.0\" }\n", @@ -2281,8 +2285,8 @@ func 
TestLoadTypedPackageGraphCustomExternalVersionChangeBustsCache(t *testing.T "HOME="+homeDir, "GOPROXY=file://"+proxyDir, "GOSUMDB=off", - "GOCACHE=/tmp/gocache", - "GOMODCACHE=/tmp/gomodcache", + "GOCACHE="+goCacheDir, + "GOMODCACHE="+goModCacheDir, loaderArtifactEnv+"=1", loaderArtifactDirEnv+"="+artifactDir, ) @@ -2386,6 +2390,8 @@ func TestRunGoListIncludesExportDataForReplacedModule(t *testing.T) { root := t.TempDir() depRoot := filepath.Join(root, "depmod") appRoot := filepath.Join(root, "appmod") + goCacheDir := tempCacheDirForTest(t, "wire-gocache-") + goModCacheDir := tempCacheDirForTest(t, "wire-gomodcache-") writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") writeTestFile(t, filepath.Join(depRoot, "dep.go"), "package dep\n\nfunc New() string { return \"ok\" }\n") @@ -2404,7 +2410,7 @@ func TestRunGoListIncludesExportDataForReplacedModule(t *testing.T) { meta, err := runGoList(context.Background(), goListRequest{ WD: appRoot, - Env: append(os.Environ(), "GOCACHE=/tmp/gocache", "GOMODCACHE=/tmp/gomodcache"), + Env: append(os.Environ(), "GOCACHE="+goCacheDir, "GOMODCACHE="+goModCacheDir), Patterns: []string{"example.com/app/app"}, NeedDeps: true, }) @@ -2426,6 +2432,8 @@ func TestLoadTypedPackageGraphCustomReplaceTargetWithExportDataWarmParity(t *tes appRoot := filepath.Join(root, "appmod") artifactDir := t.TempDir() homeDir := t.TempDir() + goCacheDir := tempCacheDirForTest(t, "wire-gocache-") + goModCacheDir := tempCacheDirForTest(t, "wire-gomodcache-") writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ @@ -2456,8 +2464,8 @@ func TestLoadTypedPackageGraphCustomReplaceTargetWithExportDataWarmParity(t *tes env := append(os.Environ(), "HOME="+homeDir, - "GOCACHE=/tmp/gocache", - "GOMODCACHE=/tmp/gomodcache", + "GOCACHE="+goCacheDir, + "GOMODCACHE="+goModCacheDir, loaderArtifactEnv+"=1", 
loaderArtifactDirEnv+"="+artifactDir, ) @@ -2517,6 +2525,8 @@ func TestLoadTypedPackageGraphCustomReplaceTargetWithoutExportDataWarmParity(t * depRoot := filepath.Join(root, "depmod") appRoot := filepath.Join(root, "appmod") homeDir := t.TempDir() + goCacheDir := tempCacheDirForTest(t, "wire-gocache-") + goModCacheDir := tempCacheDirForTest(t, "wire-gomodcache-") writeTestFile(t, filepath.Join(depRoot, "go.mod"), "module example.com/dep\n\ngo 1.19\n") writeTestFile(t, filepath.Join(depRoot, "dep.go"), strings.Join([]string{ @@ -2549,8 +2559,8 @@ func TestLoadTypedPackageGraphCustomReplaceTargetWithoutExportDataWarmParity(t * env := append(os.Environ(), "HOME="+homeDir, - "GOCACHE=/tmp/gocache", - "GOMODCACHE=/tmp/gomodcache", + "GOCACHE="+goCacheDir, + "GOMODCACHE="+goModCacheDir, ) meta, err := runGoList(context.Background(), goListRequest{ @@ -3120,6 +3130,29 @@ func appendLineIfMissing(t *testing.T, path string, line string) { writeTestFile(t, path, content) } +func tempCacheDirForTest(t *testing.T, pattern string) string { + t.Helper() + dir, err := os.MkdirTemp("", pattern) + if err != nil { + t.Fatalf("MkdirTemp(%q) error = %v", pattern, err) + } + t.Cleanup(func() { + _ = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return nil + } + if info.IsDir() { + _ = os.Chmod(path, 0o755) + return nil + } + _ = os.Chmod(path, 0o644) + return nil + }) + _ = os.RemoveAll(dir) + }) + return dir +} + type importerFuncForTest func(string) (*types.Package, error) func (f importerFuncForTest) Import(path string) (*types.Package, error) { diff --git a/internal/runtests.sh b/internal/runtests.sh index 7d2ddcb..905e319 100755 --- a/internal/runtests.sh +++ b/internal/runtests.sh @@ -16,8 +16,13 @@ # https://coderwall.com/p/fkfaqq/safer-bash-scripts-with-set-euxo-pipefail set -euo pipefail -export GOCACHE="${GOCACHE:-/tmp/gocache}" -export GOMODCACHE="${GOMODCACHE:-/tmp/gomodcache}" +tmp_root="${TMPDIR:-${RUNNER_TEMP:-}}" +if [[ -z 
"${tmp_root}" ]]; then + tmp_root="$(mktemp -d)" +fi + +export GOCACHE="${GOCACHE:-${tmp_root}/gocache}" +export GOMODCACHE="${GOMODCACHE:-${tmp_root}/gomodcache}" if [[ $# -gt 0 ]]; then echo "usage: runtests.sh" 1>&2 From 114d1740875b182ff6453ed7232df47d1b76a693 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 16:40:02 -0500 Subject: [PATCH 26/82] fix(loader): use valid file GOPROXY URLs in proxy-based tests --- internal/loader/loader_test.go | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 4e94afc..d539347 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -23,6 +23,7 @@ import ( "go/parser" "go/token" "go/types" + "net/url" "os" "os/exec" "path/filepath" @@ -2226,7 +2227,7 @@ func TestDiscoveryCacheInvalidatesOnGoSumResolutionChange(t *testing.T) { env := append(os.Environ(), "HOME="+homeDir, - "GOPROXY=file://"+proxyDir, + "GOPROXY="+fileURLForTest(t, proxyDir), "GOSUMDB=off", "GOCACHE="+goCacheDir, "GOMODCACHE="+goModCacheDir, @@ -2283,7 +2284,7 @@ func TestLoadTypedPackageGraphCustomExternalVersionChangeBustsCache(t *testing.T env := append(os.Environ(), "HOME="+homeDir, - "GOPROXY=file://"+proxyDir, + "GOPROXY="+fileURLForTest(t, proxyDir), "GOSUMDB=off", "GOCACHE="+goCacheDir, "GOMODCACHE="+goModCacheDir, @@ -3153,6 +3154,15 @@ func tempCacheDirForTest(t *testing.T, pattern string) string { return dir } +func fileURLForTest(t *testing.T, path string) string { + t.Helper() + u := &url.URL{ + Scheme: "file", + Path: filepath.ToSlash(path), + } + return u.String() +} + type importerFuncForTest func(string) (*types.Package, error) func (f importerFuncForTest) Import(path string) (*types.Package, error) { From 53890d7a7037c92179c6a6542789174be0f74c3f Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 17:00:43 -0500 Subject: [PATCH 27/82] fix(loader): format file GOPROXY URLs correctly on windows 
--- internal/loader/loader_test.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index d539347..a0d96bc 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -23,7 +23,6 @@ import ( "go/parser" "go/token" "go/types" - "net/url" "os" "os/exec" "path/filepath" @@ -3156,11 +3155,11 @@ func tempCacheDirForTest(t *testing.T, pattern string) string { func fileURLForTest(t *testing.T, path string) string { t.Helper() - u := &url.URL{ - Scheme: "file", - Path: filepath.ToSlash(path), + slashed := filepath.ToSlash(path) + if !strings.HasPrefix(slashed, "/") { + slashed = "/" + slashed } - return u.String() + return "file://" + slashed } type importerFuncForTest func(string) (*types.Package, error) From efa02144b77bb6e99e10d0f1d412b8b231e393f2 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 17:04:50 -0500 Subject: [PATCH 28/82] fix(loader): normalize test path comparisons across platforms --- internal/loader/loader_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index a0d96bc..05cfaa7 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -3182,9 +3182,9 @@ func normalizePathForCompare(path string) string { return "" } if resolved, err := filepath.EvalSymlinks(path); err == nil && resolved != "" { - return filepath.Clean(resolved) + return filepath.ToSlash(filepath.Clean(resolved)) } - return filepath.Clean(path) + return filepath.ToSlash(filepath.Clean(path)) } func comparableErrors(errs []packages.Error) []string { From 541acdfe8fd2dc42d53f1988ed1f4738c0c1af4a Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:31:37 -0500 Subject: [PATCH 29/82] refactor: remove unused loader and wire helpers --- internal/loader/discovery_cache.go | 8 -------- internal/wire/loader_validation.go | 6 +----- 
internal/wire/wire.go | 20 -------------------- 3 files changed, 1 insertion(+), 33 deletions(-) diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go index e3db86b..9d7d932 100644 --- a/internal/loader/discovery_cache.go +++ b/internal/loader/discovery_cache.go @@ -60,14 +60,6 @@ func readDiscoveryCache(req goListRequest) (map[string]*packageMeta, bool) { return clonePackageMetaMap(entry.Meta), true } -func writeDiscoveryCache(req goListRequest, meta map[string]*packageMeta) { - entry, err := buildDiscoveryCacheEntry(req, meta) - if err != nil { - return - } - _ = saveDiscoveryCacheEntry(req, entry) -} - func buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) (*discoveryCacheEntry, error) { workspace := detectModuleRoot(req.WD) entry := &discoveryCacheEntry{ diff --git a/internal/wire/loader_validation.go b/internal/wire/loader_validation.go index 6868b7b..cde4d60 100644 --- a/internal/wire/loader_validation.go +++ b/internal/wire/loader_validation.go @@ -20,11 +20,7 @@ import ( "github.com/goforj/wire/internal/loader" ) -func loaderValidationMode(ctx context.Context, wd string, env []string) bool { - return effectiveLoaderMode(ctx, wd, env) != loader.ModeFallback -} - -func effectiveLoaderMode(ctx context.Context, wd string, env []string) loader.Mode { +func effectiveLoaderMode(_ context.Context, _ string, env []string) loader.Mode { mode := loader.ModeFromEnv(env) if mode != loader.ModeAuto { return mode diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 3d787f3..2459723 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -170,26 +170,6 @@ func detectOutputDir(paths []string) (string, error) { return dir, nil } -func manifestOutputPkgPathsFromGenerated(generated []GenerateResult) []string { - if len(generated) == 0 { - return nil - } - seen := make(map[string]struct{}, len(generated)) - out := make([]string, 0, len(generated)) - for _, gen := range generated { - if gen.PkgPath == 
"" { - continue - } - if _, ok := seen[gen.PkgPath]; ok { - continue - } - seen[gen.PkgPath] = struct{}{} - out = append(out, gen.PkgPath) - } - sort.Strings(out) - return out -} - // generateInjectors generates the injectors for a given package. func generateInjectors(oc *objectCache, g *gen, pkg *packages.Package) (injectorFiles []*ast.File, _ []error) { injectorFiles = make([]*ast.File, 0, len(pkg.Syntax)) From 40ab1445c19c08ecb79cb6a61832413c7feaf807 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:31:40 -0500 Subject: [PATCH 30/82] refactor: dedupe command and custom loader helpers --- cmd/wire/gen_cmd.go | 43 +------ cmd/wire/generate_runner.go | 59 +++++++++ cmd/wire/logging.go | 133 ++++++++++++++++++++ cmd/wire/main.go | 125 ------------------- cmd/wire/watch_cmd.go | 48 +------- internal/loader/custom.go | 234 ++++++++++++++---------------------- 6 files changed, 287 insertions(+), 355 deletions(-) create mode 100644 cmd/wire/generate_runner.go create mode 100644 cmd/wire/logging.go diff --git a/cmd/wire/gen_cmd.go b/cmd/wire/gen_cmd.go index aceefee..246caa5 100644 --- a/cmd/wire/gen_cmd.go +++ b/cmd/wire/gen_cmd.go @@ -19,9 +19,7 @@ import ( "flag" "log" "os" - "time" - "github.com/goforj/wire/internal/wire" "github.com/google/subcommands" ) @@ -66,7 +64,6 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa return subcommands.ExitFailure } defer stop() - totalStart := time.Now() ctx = withTiming(ctx, cmd.profile.timings) wd, err := os.Getwd() @@ -83,46 +80,8 @@ func (cmd *genCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interfa opts.PrefixOutputFile = cmd.prefixFileName opts.Tags = cmd.tags - genStart := time.Now() - outs, errs := wire.Generate(ctx, wd, os.Environ(), packages(f), opts) - logTiming(cmd.profile.timings, "wire.Generate", genStart) - if len(errs) > 0 { - logErrors(errs) - log.Println("generate failed") + if !runGenerateCommand(ctx, wd, os.Environ(), packages(f), opts, 
cmd.profile.timings) { return subcommands.ExitFailure } - if len(outs) == 0 { - logTiming(cmd.profile.timings, "total", totalStart) - return subcommands.ExitSuccess - } - success := true - writeStart := time.Now() - for _, out := range outs { - if len(out.Errs) > 0 { - logErrors(out.Errs) - log.Printf("%s: generate failed\n", out.PkgPath) - success = false - } - if len(out.Content) == 0 { - // No Wire output. Maybe errors, maybe no Wire directives. - continue - } - if wrote, err := out.CommitWithStatus(); err == nil { - if wrote { - logSuccessf("%s: wrote %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) - } else { - logSuccessf("%s: unchanged %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) - } - } else { - log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) - success = false - } - } - if !success { - log.Println("at least one generate failure") - return subcommands.ExitFailure - } - logTiming(cmd.profile.timings, "writes", writeStart) - logTiming(cmd.profile.timings, "total", totalStart) return subcommands.ExitSuccess } diff --git a/cmd/wire/generate_runner.go b/cmd/wire/generate_runner.go new file mode 100644 index 0000000..4dc7b20 --- /dev/null +++ b/cmd/wire/generate_runner.go @@ -0,0 +1,59 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/goforj/wire/internal/wire" +) + +func runGenerateCommand(ctx context.Context, wd string, env []string, patterns []string, opts *wire.GenerateOptions, timings bool) bool { + totalStart := time.Now() + genStart := time.Now() + outs, errs := wire.Generate(ctx, wd, env, patterns, opts) + logTiming(timings, "wire.Generate", genStart) + if len(errs) > 0 { + logErrors(errs) + log.Println("generate failed") + return false + } + if len(outs) == 0 { + logTiming(timings, "total", totalStart) + return true + } + success := true + writeStart := time.Now() + for _, out := range outs { + if len(out.Errs) > 0 { + 
logErrors(out.Errs) + log.Printf("%s: generate failed\n", out.PkgPath) + success = false + } + if len(out.Content) == 0 { + continue + } + if wrote, err := out.CommitWithStatus(); err == nil { + if wrote { + logSuccessf("%s: wrote %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + } else { + logSuccessf("%s: unchanged %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) + } + } else { + log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) + success = false + } + } + if !success { + log.Println("at least one generate failure") + return false + } + logTiming(timings, "writes", writeStart) + logTiming(timings, "total", totalStart) + return true +} + +func formatDuration(d time.Duration) string { + return fmt.Sprintf("%.2fms", float64(d)/float64(time.Millisecond)) +} diff --git a/cmd/wire/logging.go b/cmd/wire/logging.go new file mode 100644 index 0000000..7479f13 --- /dev/null +++ b/cmd/wire/logging.go @@ -0,0 +1,133 @@ +package main + +import ( + "fmt" + "io" + "os" + "strings" +) + +const ( + ansiRed = "\033[1;31m" + ansiGreen = "\033[1;32m" + ansiReset = "\033[0m" + successSig = "✓ " + errorSig = "x " + maxLoggedErrorLines = 5 +) + +func logErrors(errs []error) { + for _, err := range errs { + msg := truncateLoggedError(formatLoggedError(err)) + if strings.Contains(msg, "\n") { + logMultilineError("\n " + strings.ReplaceAll(msg, "\n", "\n ")) + continue + } + logMultilineError(msg) + } +} + +func formatLoggedError(err error) string { + if err == nil { + return "" + } + msg := err.Error() + if strings.HasPrefix(msg, "inject ") { + return "solve failed\n" + msg + } + if idx := strings.Index(msg, ": inject "); idx >= 0 { + return "solve failed\n" + msg + } + return msg +} + +func truncateLoggedError(msg string) string { + if msg == "" { + return "" + } + lines := strings.Split(msg, "\n") + if len(lines) <= maxLoggedErrorLines { + return msg + } + omitted := len(lines) - maxLoggedErrorLines 
+ lines = append(lines[:maxLoggedErrorLines], fmt.Sprintf("... (%d additional lines omitted)", omitted)) + return strings.Join(lines, "\n") +} + +func logMultilineError(msg string) { + writeErrorLog(os.Stderr, msg) +} + +func logSuccessf(format string, args ...interface{}) { + writeStatusLog(os.Stderr, fmt.Sprintf(format, args...)) +} + +func shouldColorStderr() bool { + return shouldColorOutput(stderrIsTTY(), os.Getenv("TERM")) +} + +func shouldColorOutput(isTTY bool, term string) bool { + if os.Getenv("NO_COLOR") != "" || os.Getenv("CLICOLOR") == "0" { + return false + } + if forceColorEnabled() { + return true + } + if term == "" || term == "dumb" { + return false + } + return isTTY +} + +func forceColorEnabled() bool { + return os.Getenv("FORCE_COLOR") != "" || os.Getenv("CLICOLOR_FORCE") != "" +} + +func stderrIsTTY() bool { + info, err := os.Stderr.Stat() + if err != nil { + return false + } + return (info.Mode() & os.ModeCharDevice) != 0 +} + +func writeErrorLog(w io.Writer, msg string) { + line := errorSig + "wire: " + msg + if !strings.HasSuffix(line, "\n") { + line += "\n" + } + if shouldColorStderr() { + _, _ = io.WriteString(w, colorizeLines(line)) + return + } + _, _ = io.WriteString(w, line) +} + +func writeStatusLog(w io.Writer, msg string) { + line := successSig + "wire: " + msg + if !strings.HasSuffix(line, "\n") { + line += "\n" + } + if shouldColorStderr() { + _, _ = io.WriteString(w, ansiGreen+line+ansiReset) + return + } + _, _ = io.WriteString(w, line) +} + +func colorizeLines(s string) string { + if s == "" { + return "" + } + parts := strings.SplitAfter(s, "\n") + var b strings.Builder + for _, part := range parts { + if part == "" { + continue + } + b.WriteString(ansiRed) + b.WriteString(part) + b.WriteString(ansiReset) + } + return b.String() +} diff --git a/cmd/wire/main.go b/cmd/wire/main.go index c13b850..ada16d2 100644 --- a/cmd/wire/main.go +++ b/cmd/wire/main.go @@ -21,7 +21,6 @@ import ( "context" "flag" "fmt" - "io" "io/ioutil" 
"log" "os" @@ -35,15 +34,6 @@ import ( "github.com/google/subcommands" ) -const ( - ansiRed = "\033[1;31m" - ansiGreen = "\033[1;32m" - ansiReset = "\033[0m" - successSig = "✓ " - errorSig = "x " - maxLoggedErrorLines = 5 -) - // main wires up subcommands and executes the selected command. func main() { subcommands.Register(subcommands.CommandsCommand(), "") @@ -212,118 +202,3 @@ func newGenerateOptions(headerFile string) (*wire.GenerateOptions, error) { } // logErrors logs each error with consistent formatting. -func logErrors(errs []error) { - for _, err := range errs { - msg := truncateLoggedError(formatLoggedError(err)) - if strings.Contains(msg, "\n") { - logMultilineError("\n " + strings.ReplaceAll(msg, "\n", "\n ")) - continue - } - logMultilineError(msg) - } -} - -func formatLoggedError(err error) string { - if err == nil { - return "" - } - msg := err.Error() - if strings.HasPrefix(msg, "inject ") { - return "solve failed\n" + msg - } - if idx := strings.Index(msg, ": inject "); idx >= 0 { - return "solve failed\n" + msg - } - return msg -} - -func truncateLoggedError(msg string) string { - if msg == "" { - return "" - } - lines := strings.Split(msg, "\n") - if len(lines) <= maxLoggedErrorLines { - return msg - } - omitted := len(lines) - maxLoggedErrorLines - lines = append(lines[:maxLoggedErrorLines], fmt.Sprintf("... 
(%d additional lines omitted)", omitted)) - return strings.Join(lines, "\n") -} - -func logMultilineError(msg string) { - writeErrorLog(os.Stderr, msg) -} - -func logSuccessf(format string, args ...interface{}) { - writeStatusLog(os.Stderr, fmt.Sprintf(format, args...)) -} - -func shouldColorStderr() bool { - return shouldColorOutput(stderrIsTTY(), os.Getenv("TERM")) -} - -func shouldColorOutput(isTTY bool, term string) bool { - if os.Getenv("NO_COLOR") != "" || os.Getenv("CLICOLOR") == "0" { - return false - } - if forceColorEnabled() { - return true - } - if term == "" || term == "dumb" { - return false - } - return isTTY -} - -func forceColorEnabled() bool { - return os.Getenv("FORCE_COLOR") != "" || os.Getenv("CLICOLOR_FORCE") != "" -} - -func stderrIsTTY() bool { - info, err := os.Stderr.Stat() - if err != nil { - return false - } - return (info.Mode() & os.ModeCharDevice) != 0 -} - -func writeErrorLog(w io.Writer, msg string) { - line := errorSig + "wire: " + msg - if !strings.HasSuffix(line, "\n") { - line += "\n" - } - if shouldColorStderr() { - _, _ = io.WriteString(w, colorizeLines(line)) - return - } - _, _ = io.WriteString(w, line) -} - -func writeStatusLog(w io.Writer, msg string) { - line := successSig + "wire: " + msg - if !strings.HasSuffix(line, "\n") { - line += "\n" - } - if shouldColorStderr() { - _, _ = io.WriteString(w, ansiGreen+line+ansiReset) - return - } - _, _ = io.WriteString(w, line) -} - -func colorizeLines(s string) string { - if s == "" { - return "" - } - parts := strings.SplitAfter(s, "\n") - var b strings.Builder - for _, part := range parts { - if part == "" { - continue - } - b.WriteString(ansiRed) - b.WriteString(part) - b.WriteString(ansiReset) - } - return b.String() -} diff --git a/cmd/wire/watch_cmd.go b/cmd/wire/watch_cmd.go index ebdfa0e..45b6bc4 100644 --- a/cmd/wire/watch_cmd.go +++ b/cmd/wire/watch_cmd.go @@ -27,7 +27,6 @@ import ( "time" "github.com/fsnotify/fsnotify" - "github.com/goforj/wire/internal/wire" 
"github.com/google/subcommands" ) @@ -102,47 +101,7 @@ func (cmd *watchCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter env := os.Environ() runGenerate := func() { - totalStart := time.Now() - genStart := time.Now() - outs, errs := wire.Generate(ctx, wd, env, packages(f), opts) - logTiming(cmd.profile.timings, "wire.Generate", genStart) - if len(errs) > 0 { - logErrors(errs) - log.Println("generate failed") - return - } - if len(outs) == 0 { - logTiming(cmd.profile.timings, "total", totalStart) - return - } - success := true - writeStart := time.Now() - for _, out := range outs { - if len(out.Errs) > 0 { - logErrors(out.Errs) - log.Printf("%s: generate failed\n", out.PkgPath) - success = false - } - if len(out.Content) == 0 { - continue - } - if wrote, err := out.CommitWithStatus(); err == nil { - if wrote { - logSuccessf("%s: wrote %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) - } else { - logSuccessf("%s: unchanged %s (%s)", out.PkgPath, out.OutputPath, formatDuration(time.Since(totalStart))) - } - } else { - log.Printf("%s: failed to write %s: %v\n", out.PkgPath, out.OutputPath, err) - success = false - } - } - if !success { - log.Println("at least one generate failure") - return - } - logTiming(cmd.profile.timings, "writes", writeStart) - logTiming(cmd.profile.timings, "total", totalStart) + _ = runGenerateCommand(ctx, wd, env, packages(f), opts, cmd.profile.timings) } root, err := moduleRoot(wd, env) @@ -332,11 +291,6 @@ func moduleRoot(wd string, env []string) (string, error) { return filepath.Dir(path), nil } -// formatDuration renders a short millisecond duration for log output. -func formatDuration(d time.Duration) string { - return fmt.Sprintf("%.2fms", float64(d)/float64(time.Millisecond)) -} - // watchWithFSNotify runs the watcher using native filesystem notifications. 
func watchWithFSNotify(root string, onChange func()) error { watcher, err := fsnotify.NewWatcher() diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 18c0f7d..afdafc0 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -188,22 +188,8 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes } pkgs := make(map[string]*packages.Package, len(meta)) for path, m := range meta { - pkgs[path] = &packages.Package{ - ID: m.ImportPath, - Name: m.Name, - PkgPath: m.ImportPath, - GoFiles: append([]string(nil), metaFiles(m)...), - CompiledGoFiles: append([]string(nil), metaFiles(m)...), - ExportFile: m.Export, - Imports: make(map[string]*packages.Package), - } - if m.Error != nil && strings.TrimSpace(m.Error.Err) != "" { - pkgs[path].Errors = append(pkgs[path].Errors, packages.Error{ - Pos: "-", - Msg: m.Error.Err, - Kind: packages.ListError, - }) - } + pkgs[path] = packageStub(nil, m) + appendPackageMetaError(pkgs[path], m) } for path, m := range meta { pkg := pkgs[path] @@ -217,12 +203,10 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes } } } - roots := make([]*packages.Package, 0, len(req.Patterns)) - for _, m := range meta { - if m.DepOnly { - continue - } - if pkg := pkgs[m.ImportPath]; pkg != nil { + rootPaths := nonDepRootImportPaths(meta) + roots := make([]*packages.Package, 0, len(rootPaths)) + for _, path := range rootPaths { + if pkg := pkgs[path]; pkg != nil { roots = append(roots, pkg) } } @@ -272,23 +256,7 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz if fset == nil { fset = token.NewFileSet() } - l := &customTypedGraphLoader{ - workspace: detectModuleRoot(req.WD), - ctx: ctx, - env: append([]string(nil), req.Env...), - fset: fset, - meta: meta, - targets: map[string]struct{}{req.Package: {}}, - parseFile: req.ParseFile, - packages: make(map[string]*packages.Package, len(meta)), - typesPkgs: make(map[string]*types.Package, 
len(meta)), - importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), - loading: make(map[string]bool, len(meta)), - isLocalCache: make(map[string]bool, len(meta)), - localSemanticOK: make(map[string]bool, len(meta)), - artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), - stats: typedLoadStats{discovery: discoveryDuration}, - } + l := newCustomTypedGraphLoader(ctx, req.WD, req.Env, fset, meta, map[string]struct{}{req.Package: {}}, req.ParseFile, discoveryDuration) prefetchStart := time.Now() l.prefetchArtifacts() l.stats.artifactPrefetch = time.Since(prefetchStart) @@ -298,26 +266,7 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz return nil, err } l.stats.rootLoad = time.Since(rootLoadStart) - logDuration(ctx, "loader.custom.lazy.read_files.cumulative", l.stats.read) - logDuration(ctx, "loader.custom.lazy.parse_files.cumulative", l.stats.parse) - logDuration(ctx, "loader.custom.lazy.typecheck.cumulative", l.stats.typecheck) - logDuration(ctx, "loader.custom.lazy.read_files.local.cumulative", l.stats.localRead) - logDuration(ctx, "loader.custom.lazy.read_files.external.cumulative", l.stats.externalRead) - logDuration(ctx, "loader.custom.lazy.parse_files.local.cumulative", l.stats.localParse) - logDuration(ctx, "loader.custom.lazy.parse_files.external.cumulative", l.stats.externalParse) - logDuration(ctx, "loader.custom.lazy.typecheck.local.cumulative", l.stats.localTypecheck) - logDuration(ctx, "loader.custom.lazy.typecheck.external.cumulative", l.stats.externalTypecheck) - logDuration(ctx, "loader.custom.lazy.artifact_read", l.stats.artifactRead) - logDuration(ctx, "loader.custom.lazy.artifact_path", l.stats.artifactPath) - logDuration(ctx, "loader.custom.lazy.artifact_decode", l.stats.artifactDecode) - logDuration(ctx, "loader.custom.lazy.artifact_import_link", l.stats.artifactImportLink) - logDuration(ctx, "loader.custom.lazy.artifact_write", l.stats.artifactWrite) - logDuration(ctx, 
"loader.custom.lazy.artifact_prefetch.wall", l.stats.artifactPrefetch) - logDuration(ctx, "loader.custom.lazy.root_load.wall", l.stats.rootLoad) - logDuration(ctx, "loader.custom.lazy.discovery.wall", l.stats.discovery) - logInt(ctx, "loader.custom.lazy.artifact_hits", l.stats.artifactHits) - logInt(ctx, "loader.custom.lazy.artifact_misses", l.stats.artifactMisses) - logInt(ctx, "loader.custom.lazy.artifact_writes", l.stats.artifactWrites) + logTypedLoadStats(ctx, "lazy", l.stats) return &LazyLoadResult{ Packages: []*packages.Package{root}, Backend: ModeCustom, @@ -345,42 +294,21 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo fset = token.NewFileSet() } targets := make(map[string]struct{}) - for _, m := range meta { - if m.DepOnly { - continue - } - targets[m.ImportPath] = struct{}{} + for _, path := range nonDepRootImportPaths(meta) { + targets[path] = struct{}{} } if len(targets) == 0 { return nil, unsupportedError{reason: "no root packages from metadata"} } - l := &customTypedGraphLoader{ - workspace: detectModuleRoot(req.WD), - ctx: ctx, - env: append([]string(nil), req.Env...), - fset: fset, - meta: meta, - targets: targets, - parseFile: req.ParseFile, - packages: make(map[string]*packages.Package, len(meta)), - typesPkgs: make(map[string]*types.Package, len(meta)), - importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), - loading: make(map[string]bool, len(meta)), - isLocalCache: make(map[string]bool, len(meta)), - localSemanticOK: make(map[string]bool, len(meta)), - artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), - stats: typedLoadStats{discovery: discoveryDuration}, - } + l := newCustomTypedGraphLoader(ctx, req.WD, req.Env, fset, meta, targets, req.ParseFile, discoveryDuration) prefetchStart := time.Now() l.prefetchArtifacts() l.stats.artifactPrefetch = time.Since(prefetchStart) rootLoadStart := time.Now() - roots := make([]*packages.Package, 0, len(targets)) - for _, m := range meta { - 
if m.DepOnly { - continue - } - root, err := l.loadPackage(m.ImportPath) + rootPaths := nonDepRootImportPaths(meta) + roots := make([]*packages.Package, 0, len(rootPaths)) + for _, path := range rootPaths { + root, err := l.loadPackage(path) if err != nil { return nil, err } @@ -388,26 +316,7 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo } l.stats.rootLoad = time.Since(rootLoadStart) sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) - logDuration(ctx, "loader.custom.typed.read_files.cumulative", l.stats.read) - logDuration(ctx, "loader.custom.typed.parse_files.cumulative", l.stats.parse) - logDuration(ctx, "loader.custom.typed.typecheck.cumulative", l.stats.typecheck) - logDuration(ctx, "loader.custom.typed.read_files.local.cumulative", l.stats.localRead) - logDuration(ctx, "loader.custom.typed.read_files.external.cumulative", l.stats.externalRead) - logDuration(ctx, "loader.custom.typed.parse_files.local.cumulative", l.stats.localParse) - logDuration(ctx, "loader.custom.typed.parse_files.external.cumulative", l.stats.externalParse) - logDuration(ctx, "loader.custom.typed.typecheck.local.cumulative", l.stats.localTypecheck) - logDuration(ctx, "loader.custom.typed.typecheck.external.cumulative", l.stats.externalTypecheck) - logDuration(ctx, "loader.custom.typed.artifact_read", l.stats.artifactRead) - logDuration(ctx, "loader.custom.typed.artifact_path", l.stats.artifactPath) - logDuration(ctx, "loader.custom.typed.artifact_decode", l.stats.artifactDecode) - logDuration(ctx, "loader.custom.typed.artifact_import_link", l.stats.artifactImportLink) - logDuration(ctx, "loader.custom.typed.artifact_write", l.stats.artifactWrite) - logDuration(ctx, "loader.custom.typed.artifact_prefetch.wall", l.stats.artifactPrefetch) - logDuration(ctx, "loader.custom.typed.root_load.wall", l.stats.rootLoad) - logDuration(ctx, "loader.custom.typed.discovery.wall", l.stats.discovery) - logInt(ctx, 
"loader.custom.typed.artifact_hits", l.stats.artifactHits) - logInt(ctx, "loader.custom.typed.artifact_misses", l.stats.artifactMisses) - logInt(ctx, "loader.custom.typed.artifact_writes", l.stats.artifactWrites) + logTypedLoadStats(ctx, "typed", l.stats) return &PackageLoadResult{ Packages: roots, Backend: ModeCustom, @@ -424,22 +333,8 @@ func (v *customValidator) validatePackage(path string) (*packages.Package, error } v.loading[path] = true defer delete(v.loading, path) - pkg := &packages.Package{ - ID: meta.ImportPath, - Name: meta.Name, - PkgPath: meta.ImportPath, - Fset: v.fset, - GoFiles: append([]string(nil), metaFiles(meta)...), - CompiledGoFiles: append([]string(nil), metaFiles(meta)...), - Imports: make(map[string]*packages.Package), - ExportFile: meta.Export, - } - if meta.Error != nil && strings.TrimSpace(meta.Error.Err) != "" { - pkg.Errors = append(pkg.Errors, packages.Error{ - Pos: "-", - Msg: meta.Error.Err, - Kind: packages.ListError, - }) + pkg := packageStub(v.fset, meta) + if appendPackageMetaError(pkg, meta) { return pkg, nil } files, errs := v.parseFiles(metaFiles(meta)) @@ -563,16 +458,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er } if pkg == nil { - pkg = &packages.Package{ - ID: meta.ImportPath, - Name: meta.Name, - PkgPath: meta.ImportPath, - Fset: l.fset, - GoFiles: append([]string(nil), metaFiles(meta)...), - CompiledGoFiles: append([]string(nil), metaFiles(meta)...), - Imports: make(map[string]*packages.Package), - ExportFile: meta.Export, - } + pkg = packageStub(l.fset, meta) l.packages[path] = pkg } useArtifact := l.shouldUseArtifact(path, meta, isTarget, isLocal) @@ -600,13 +486,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er files, parseErrs := l.parseFiles(metaFiles(meta), isLocal) pkg.Errors = append(pkg.Errors, parseErrs...) 
if len(files) == 0 { - if meta.Error != nil && strings.TrimSpace(meta.Error.Err) != "" { - pkg.Errors = append(pkg.Errors, packages.Error{ - Pos: "-", - Msg: meta.Error.Err, - Kind: packages.ListError, - }) - } + appendPackageMetaError(pkg, meta) return pkg, nil } @@ -1358,7 +1238,27 @@ func envValue(env []string, key string) string { return "" } -func touchedPackageStub(fset *token.FileSet, meta *packageMeta) *packages.Package { +func newCustomTypedGraphLoader(ctx context.Context, wd string, env []string, fset *token.FileSet, meta map[string]*packageMeta, targets map[string]struct{}, parseFile ParseFileFunc, discoveryDuration time.Duration) *customTypedGraphLoader { + return &customTypedGraphLoader{ + workspace: detectModuleRoot(wd), + ctx: ctx, + env: append([]string(nil), env...), + fset: fset, + meta: meta, + targets: targets, + parseFile: parseFile, + packages: make(map[string]*packages.Package, len(meta)), + typesPkgs: make(map[string]*types.Package, len(meta)), + importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), + loading: make(map[string]bool, len(meta)), + isLocalCache: make(map[string]bool, len(meta)), + localSemanticOK: make(map[string]bool, len(meta)), + artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), + stats: typedLoadStats{discovery: discoveryDuration}, + } +} + +func packageStub(fset *token.FileSet, meta *packageMeta) *packages.Package { if meta == nil { return nil } @@ -1374,6 +1274,58 @@ func touchedPackageStub(fset *token.FileSet, meta *packageMeta) *packages.Packag } } +func appendPackageMetaError(pkg *packages.Package, meta *packageMeta) bool { + if pkg == nil || meta == nil || meta.Error == nil || strings.TrimSpace(meta.Error.Err) == "" { + return false + } + pkg.Errors = append(pkg.Errors, packages.Error{ + Pos: "-", + Msg: meta.Error.Err, + Kind: packages.ListError, + }) + return true +} + +func nonDepRootImportPaths(meta map[string]*packageMeta) []string { + paths := make([]string, 0, len(meta)) + for _, 
m := range meta { + if m == nil || m.DepOnly { + continue + } + paths = append(paths, m.ImportPath) + } + sort.Strings(paths) + return paths +} + +func logTypedLoadStats(ctx context.Context, mode string, stats typedLoadStats) { + prefix := "loader.custom." + mode + logDuration(ctx, prefix+".read_files.cumulative", stats.read) + logDuration(ctx, prefix+".parse_files.cumulative", stats.parse) + logDuration(ctx, prefix+".typecheck.cumulative", stats.typecheck) + logDuration(ctx, prefix+".read_files.local.cumulative", stats.localRead) + logDuration(ctx, prefix+".read_files.external.cumulative", stats.externalRead) + logDuration(ctx, prefix+".parse_files.local.cumulative", stats.localParse) + logDuration(ctx, prefix+".parse_files.external.cumulative", stats.externalParse) + logDuration(ctx, prefix+".typecheck.local.cumulative", stats.localTypecheck) + logDuration(ctx, prefix+".typecheck.external.cumulative", stats.externalTypecheck) + logDuration(ctx, prefix+".artifact_read", stats.artifactRead) + logDuration(ctx, prefix+".artifact_path", stats.artifactPath) + logDuration(ctx, prefix+".artifact_decode", stats.artifactDecode) + logDuration(ctx, prefix+".artifact_import_link", stats.artifactImportLink) + logDuration(ctx, prefix+".artifact_write", stats.artifactWrite) + logDuration(ctx, prefix+".artifact_prefetch.wall", stats.artifactPrefetch) + logDuration(ctx, prefix+".root_load.wall", stats.rootLoad) + logDuration(ctx, prefix+".discovery.wall", stats.discovery) + logInt(ctx, prefix+".artifact_hits", stats.artifactHits) + logInt(ctx, prefix+".artifact_misses", stats.artifactMisses) + logInt(ctx, prefix+".artifact_writes", stats.artifactWrites) +} + +func touchedPackageStub(fset *token.FileSet, meta *packageMeta) *packages.Package { + return packageStub(fset, meta) +} + func metadataMatchesFingerprint(pkgPath string, meta map[string]*packageMeta, local []LocalPackageFingerprint) bool { for _, fp := range local { if fp.PkgPath != pkgPath { From 
0921843ca888dff91d2ac9774962e7751a111cdf Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:37:26 -0500 Subject: [PATCH 31/82] refactor: make loader artifact policy explicit --- internal/loader/custom.go | 45 ++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index afdafc0..87cab98 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -134,6 +134,11 @@ type typedLoadStats struct { artifactWrites int } +type artifactPolicy struct { + read bool + write bool +} + func validateTouchedPackagesCustom(ctx context.Context, req TouchedValidationRequest) (*TouchedValidationResult, error) { if len(req.Touched) == 0 { return &TouchedValidationResult{Backend: ModeCustom}, nil @@ -461,8 +466,8 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er pkg = packageStub(l.fset, meta) l.packages[path] = pkg } - useArtifact := l.shouldUseArtifact(path, meta, isTarget, isLocal) - if useArtifact { + artifactPolicy := l.artifactPolicy(meta, isTarget, isLocal) + if artifactPolicy.read { if typed, ok := l.readArtifact(path, meta, isLocal); ok { linkStart := time.Now() for _, imp := range meta.Imports { @@ -559,13 +564,13 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er pkg.Types = tpkg pkg.TypesInfo = info pkg.Errors = append(pkg.Errors, typeErrors...) 
- if shouldWriteArtifact(l.env, isTarget) && len(pkg.Errors) == 0 { + if artifactPolicy.write && len(pkg.Errors) == 0 { _ = l.writeArtifact(meta, tpkg, isLocal) } return pkg, nil } -func (l *customTypedGraphLoader) useLocalSemanticArtifact(meta *packageMeta) bool { +func (l *customTypedGraphLoader) localSemanticArtifactSupported(meta *packageMeta) bool { if meta == nil { return false } @@ -581,6 +586,19 @@ func (l *customTypedGraphLoader) useLocalSemanticArtifact(meta *packageMeta) boo return art.Supported } +func (l *customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isLocal bool) artifactPolicy { + if !loaderArtifactEnabled(l.env) || isTarget { + return artifactPolicy{} + } + policy := artifactPolicy{write: true} + if !isLocal { + policy.read = true + return policy + } + policy.read = l.localSemanticArtifactSupported(meta) + return policy +} + func (l *customTypedGraphLoader) checkFiles(path string, checker *types.Checker, files []*ast.File) (err error) { defer func() { if r := recover(); r != nil { @@ -656,16 +674,6 @@ func (l *customTypedGraphLoader) readArtifact(path string, meta *packageMeta, is return tpkg, true } -func (l *customTypedGraphLoader) shouldUseArtifact(path string, meta *packageMeta, isTarget, isLocal bool) bool { - if !loaderArtifactEnabled(l.env) || isTarget { - return false - } - if !isLocal { - return true - } - return l.useLocalSemanticArtifact(meta) -} - func (l *customTypedGraphLoader) prefetchArtifacts() { if !loaderArtifactEnabled(l.env) { return @@ -674,7 +682,7 @@ func (l *customTypedGraphLoader) prefetchArtifacts() { for path, meta := range l.meta { _, isTarget := l.targets[path] isLocal := l.isLocalPackage(path, meta) - if l.shouldUseArtifact(path, meta, isTarget, isLocal) { + if l.artifactPolicy(meta, isTarget, isLocal).read { candidates = append(candidates, path) } } @@ -761,13 +769,6 @@ func (l *customTypedGraphLoader) writeArtifact(meta *packageMeta, pkg *types.Pac return nil } -func shouldWriteArtifact(env 
[]string, isTarget bool) bool { - if !loaderArtifactEnabled(env) || isTarget { - return false - } - return true -} - func (l *customTypedGraphLoader) isLocalPackage(importPath string, meta *packageMeta) bool { if local, ok := l.isLocalCache[importPath]; ok { return local From bb7af77028392f6f756b4b8326758367e4a61712 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:38:07 -0500 Subject: [PATCH 32/82] refactor: dedupe custom loader import linking --- internal/loader/custom.go | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 87cab98..827b003 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -470,16 +470,8 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er if artifactPolicy.read { if typed, ok := l.readArtifact(path, meta, isLocal); ok { linkStart := time.Now() - for _, imp := range meta.Imports { - target := imp - if mapped := meta.ImportMap[imp]; mapped != "" { - target = mapped - } - dep, err := l.loadPackage(target) - if err != nil { - return nil, err - } - pkg.Imports[imp] = dep + if err := l.linkPackageImports(pkg, meta); err != nil { + return nil, err } l.stats.artifactImportLink += time.Since(linkStart) pkg.Types = typed @@ -520,10 +512,7 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er if importPath == "unsafe" { return types.Unsafe, nil } - target := importPath - if mapped := meta.ImportMap[importPath]; mapped != "" { - target = mapped - } + target := resolvedImportTarget(meta, importPath) dep, err := l.loadPackage(target) if err != nil { return nil, err @@ -599,6 +588,17 @@ func (l *customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isL return policy } +func (l *customTypedGraphLoader) linkPackageImports(pkg *packages.Package, meta *packageMeta) error { + for _, imp := range meta.Imports { + dep, err := 
l.loadPackage(resolvedImportTarget(meta, imp)) + if err != nil { + return err + } + pkg.Imports[imp] = dep + } + return nil +} + func (l *customTypedGraphLoader) checkFiles(path string, checker *types.Checker, files []*ast.File) (err error) { defer func() { if r := recover(); r != nil { @@ -1287,6 +1287,16 @@ func appendPackageMetaError(pkg *packages.Package, meta *packageMeta) bool { return true } +func resolvedImportTarget(meta *packageMeta, importPath string) string { + if meta == nil { + return importPath + } + if mapped := meta.ImportMap[importPath]; mapped != "" { + return mapped + } + return importPath +} + func nonDepRootImportPaths(meta map[string]*packageMeta) []string { paths := make([]string, 0, len(meta)) for _, m := range meta { From df22b6a67bd167af9fd282bfdff5d1d79086934c Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:38:47 -0500 Subject: [PATCH 33/82] refactor: share import target resolution in custom loader --- internal/loader/custom.go | 20 ++++---------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 827b003..a498499 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -199,10 +199,7 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes for path, m := range meta { pkg := pkgs[path] for _, imp := range m.Imports { - target := imp - if mapped := m.ImportMap[imp]; mapped != "" { - target = mapped - } + target := resolvedImportTarget(m, imp) if dep := pkgs[target]; dep != nil { pkg.Imports[imp] = dep } @@ -362,10 +359,7 @@ func (v *customValidator) validatePackage(path string) (*packages.Package, error if importPath == "unsafe" { return types.Unsafe, nil } - target := importPath - if mapped := meta.ImportMap[importPath]; mapped != "" { - target = mapped - } + target := resolvedImportTarget(meta, importPath) if _, ok := v.touched[target]; ok { if typed := v.packages[target]; typed != nil && 
typed.Complete() { if depMeta := v.meta[target]; depMeta != nil { @@ -859,10 +853,7 @@ func (v *customValidator) loadDependencyFromSource(path string) (*types.Package, if importPath == "unsafe" { return types.Unsafe, nil } - target := importPath - if mapped := meta.ImportMap[importPath]; mapped != "" { - target = mapped - } + target := resolvedImportTarget(meta, importPath) if _, ok := v.touched[target]; ok { checked, err := v.validatePackage(target) if err != nil { @@ -1023,10 +1014,7 @@ func (v *customValidator) validateDeclaredImports(meta *packageMeta, files []*as if path == "" { continue } - target := path - if mapped := meta.ImportMap[path]; mapped != "" { - target = mapped - } + target := resolvedImportTarget(meta, path) name := importName(spec) if name != "_" && name != "." { if _, ok := used[name]; !ok { From a857e0b2e77754431d7175e7e0ca62e50230da7f Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:39:36 -0500 Subject: [PATCH 34/82] refactor: centralize types info setup in custom loader --- internal/loader/custom.go | 38 ++++++++++++++------------------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index a498499..67803fd 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -347,14 +347,7 @@ func (v *customValidator) validatePackage(path string) (*packages.Package, error tpkg := types.NewPackage(meta.ImportPath, meta.Name) v.packages[meta.ImportPath] = tpkg - info := &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Defs: make(map[*ast.Ident]types.Object), - Uses: make(map[*ast.Ident]types.Object), - Implicits: make(map[ast.Node]types.Object), - Scopes: make(map[ast.Node]*types.Scope), - Selections: make(map[*ast.SelectorExpr]*types.Selection), - } + info := newTypesInfo() importer := importerFunc(func(importPath string) (*types.Package, error) { if importPath == "unsafe" { return types.Unsafe, nil @@ -489,14 +482,7 @@ func (l 
*customTypedGraphLoader) loadPackage(path string) (*packages.Package, er needFullState := isTarget || isLocal var info *types.Info if needFullState { - info = &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Defs: make(map[*ast.Ident]types.Object), - Uses: make(map[*ast.Ident]types.Object), - Implicits: make(map[ast.Node]types.Object), - Scopes: make(map[ast.Node]*types.Scope), - Selections: make(map[*ast.SelectorExpr]*types.Selection), - } + info = newTypesInfo() } var typeErrors []packages.Error cfg := &types.Config{ @@ -840,14 +826,7 @@ func (v *customValidator) loadDependencyFromSource(path string) (*types.Package, if len(errs) > 0 { return nil, unsupportedError{reason: "dependency parse error"} } - info := &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Defs: make(map[*ast.Ident]types.Object), - Uses: make(map[*ast.Ident]types.Object), - Implicits: make(map[ast.Node]types.Object), - Scopes: make(map[ast.Node]*types.Scope), - Selections: make(map[*ast.SelectorExpr]*types.Selection), - } + info := newTypesInfo() cfg := &types.Config{ Importer: importerFunc(func(importPath string) (*types.Package, error) { if importPath == "unsafe" { @@ -1285,6 +1264,17 @@ func resolvedImportTarget(meta *packageMeta, importPath string) string { return importPath } +func newTypesInfo() *types.Info { + return &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + Defs: make(map[*ast.Ident]types.Object), + Uses: make(map[*ast.Ident]types.Object), + Implicits: make(map[ast.Node]types.Object), + Scopes: make(map[ast.Node]*types.Scope), + Selections: make(map[*ast.SelectorExpr]*types.Selection), + } +} + func nonDepRootImportPaths(meta map[string]*packageMeta) []string { paths := make([]string, 0, len(meta)) for _, m := range meta { From 11b5498234fe7fe9afadeea8a9026460219ab73f Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:40:25 -0500 Subject: [PATCH 35/82] refactor: share parse error conversion in custom loader --- 
internal/loader/custom.go | 42 +++++++++++++++++---------------------- 1 file changed, 18 insertions(+), 24 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 67803fd..4cdc1e8 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -865,18 +865,7 @@ func (v *customValidator) parseFiles(names []string) ([]*ast.File, []packages.Er } f, err := parser.ParseFile(v.fset, name, src, parser.AllErrors|parser.ParseComments) if err != nil { - switch typed := err.(type) { - case scanner.ErrorList: - for _, parseErr := range typed { - errs = append(errs, packages.Error{ - Pos: parseErr.Pos.String(), - Msg: parseErr.Msg, - Kind: packages.ParseError, - }) - } - default: - errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) - } + errs = appendParseErrors(errs, name, err) } if f != nil { files = append(files, f) @@ -920,18 +909,7 @@ func (l *customTypedGraphLoader) parseFiles(names []string, isLocal bool) ([]*as l.stats.externalParse += parseDuration } if err != nil { - switch typed := err.(type) { - case scanner.ErrorList: - for _, parseErr := range typed { - errs = append(errs, packages.Error{ - Pos: parseErr.Pos.String(), - Msg: parseErr.Msg, - Kind: packages.ParseError, - }) - } - default: - errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) - } + errs = appendParseErrors(errs, name, err) } if f != nil { files = append(files, f) @@ -1275,6 +1253,22 @@ func newTypesInfo() *types.Info { } } +func appendParseErrors(errs []packages.Error, name string, err error) []packages.Error { + switch typed := err.(type) { + case scanner.ErrorList: + for _, parseErr := range typed { + errs = append(errs, packages.Error{ + Pos: parseErr.Pos.String(), + Msg: parseErr.Msg, + Kind: packages.ParseError, + }) + } + default: + errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) + } + return errs +} + func 
nonDepRootImportPaths(meta map[string]*packageMeta) []string { paths := make([]string, 0, len(meta)) for _, m := range meta { From 7bd7c65d6246a514a5fc6f7692fcf79c88c5e489 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:41:07 -0500 Subject: [PATCH 36/82] refactor: share source parsing in custom loader --- internal/loader/custom.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 4cdc1e8..3529350 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -863,7 +863,7 @@ func (v *customValidator) parseFiles(names []string) ([]*ast.File, []packages.Er errs = append(errs, packages.Error{Pos: name + ":1", Msg: err.Error(), Kind: packages.ParseError}) continue } - f, err := parser.ParseFile(v.fset, name, src, parser.AllErrors|parser.ParseComments) + f, err := parseGoSourceFile(v.fset, nil, name, src) if err != nil { errs = appendParseErrors(errs, name, err) } @@ -896,11 +896,7 @@ func (l *customTypedGraphLoader) parseFiles(names []string, isLocal bool) ([]*as } var f *ast.File parseStart := time.Now() - if l.parseFile != nil { - f, err = l.parseFile(l.fset, name, src) - } else { - f, err = parser.ParseFile(l.fset, name, src, parser.AllErrors|parser.ParseComments) - } + f, err = parseGoSourceFile(l.fset, l.parseFile, name, src) parseDuration := time.Since(parseStart) l.stats.parse += parseDuration if isLocal { @@ -1253,6 +1249,13 @@ func newTypesInfo() *types.Info { } } +func parseGoSourceFile(fset *token.FileSet, parseFile ParseFileFunc, name string, src []byte) (*ast.File, error) { + if parseFile != nil { + return parseFile(fset, name, src) + } + return parser.ParseFile(fset, name, src, parser.AllErrors|parser.ParseComments) +} + func appendParseErrors(errs []packages.Error, name string, err error) []packages.Error { switch typed := err.(type) { case scanner.ErrorList: From 6e3fed58e84c6bce7b18e863e355ff34d26c59ea Mon Sep 17 00:00:00 2001 
From: Chris Miles Date: Mon, 16 Mar 2026 18:47:19 -0500 Subject: [PATCH 37/82] refactor: centralize semantic artifact cache inputs --- internal/wire/parse.go | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 08a3c8a..34167fc 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -822,10 +822,11 @@ func (oc *objectCache) semanticArtifact(pkg *packages.Package) *semanticcache.Pa if art, ok := oc.semantic[pkg.PkgPath]; ok { return art } - if len(oc.env) == 0 || len(pkg.GoFiles) == 0 { + importPath, packageName, files, ok := semanticArtifactInputs(oc.env, pkg) + if !ok { return nil } - art, err := semanticcache.Read(oc.env, pkg.PkgPath, pkg.Name, pkg.GoFiles) + art, err := semanticcache.Read(oc.env, importPath, packageName, files) if err != nil { return nil } @@ -838,7 +839,8 @@ func (oc *objectCache) recordSemanticArtifacts() { return } for _, pkg := range oc.packages { - if pkg == nil || len(pkg.Syntax) == 0 || pkg.Types == nil || pkg.TypesInfo == nil || len(pkg.GoFiles) == 0 { + importPath, packageName, files, ok := semanticArtifactInputs(oc.env, pkg) + if !ok || len(pkg.Syntax) == 0 || pkg.Types == nil || pkg.TypesInfo == nil { continue } art := buildSemanticArtifact(pkg) @@ -846,8 +848,15 @@ func (oc *objectCache) recordSemanticArtifacts() { continue } oc.semantic[pkg.PkgPath] = art - _ = semanticcache.Write(oc.env, pkg.PkgPath, pkg.Name, pkg.GoFiles, art) + _ = semanticcache.Write(oc.env, importPath, packageName, files, art) + } +} + +func semanticArtifactInputs(env []string, pkg *packages.Package) (importPath, packageName string, files []string, ok bool) { + if len(env) == 0 || pkg == nil || len(pkg.GoFiles) == 0 { + return "", "", nil, false } + return pkg.PkgPath, pkg.Name, pkg.GoFiles, true } func buildSemanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { From c2abf4ef0d2caadf6454fb47cc9d62932bd46901 Mon Sep 17 00:00:00 2001 From: Chris 
Miles Date: Mon, 16 Mar 2026 18:48:16 -0500 Subject: [PATCH 38/82] refactor: isolate semantic artifact cache io --- internal/wire/parse.go | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 34167fc..dee48b0 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -826,7 +826,7 @@ func (oc *objectCache) semanticArtifact(pkg *packages.Package) *semanticcache.Pa if !ok { return nil } - art, err := semanticcache.Read(oc.env, importPath, packageName, files) + art, err := readSemanticArtifact(oc.env, importPath, packageName, files) if err != nil { return nil } @@ -848,7 +848,7 @@ func (oc *objectCache) recordSemanticArtifacts() { continue } oc.semantic[pkg.PkgPath] = art - _ = semanticcache.Write(oc.env, importPath, packageName, files, art) + _ = writeSemanticArtifact(oc.env, importPath, packageName, files, art) } } @@ -859,6 +859,14 @@ func semanticArtifactInputs(env []string, pkg *packages.Package) (importPath, pa return pkg.PkgPath, pkg.Name, pkg.GoFiles, true } +func readSemanticArtifact(env []string, importPath, packageName string, files []string) (*semanticcache.PackageArtifact, error) { + return semanticcache.Read(env, importPath, packageName, files) +} + +func writeSemanticArtifact(env []string, importPath, packageName string, files []string, art *semanticcache.PackageArtifact) error { + return semanticcache.Write(env, importPath, packageName, files, art) +} + func buildSemanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { if pkg == nil || pkg.Types == nil || pkg.TypesInfo == nil { return nil From 40ea02b0303b1b9c3aeb7be849209a1ec08427e3 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:49:05 -0500 Subject: [PATCH 39/82] refactor: isolate semantic provider set artifact lookup --- internal/wire/parse.go | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/internal/wire/parse.go 
b/internal/wire/parse.go index dee48b0..80411ef 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -562,23 +562,10 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { } func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, []error) { - pkg := oc.packages[obj.Pkg().Path()] - if pkg == nil { - return nil, false, nil - } - art := oc.semanticArtifact(pkg) - if art == nil || !art.Supported { - return nil, false, nil - } - setArt, ok := art.Vars[obj.Name()] + setArt, ok := oc.semanticProviderSetArtifact(obj) if !ok { return nil, false, nil } - for _, item := range setArt.Items { - if item.Kind == "bind" { - return nil, false, nil - } - } pset := &ProviderSet{ Pos: obj.Pos(), PkgPath: obj.Pkg().Path(), @@ -640,6 +627,27 @@ func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, return pset, true, nil } +func (oc *objectCache) semanticProviderSetArtifact(obj *types.Var) (semanticcache.ProviderSetArtifact, bool) { + pkg := oc.packages[obj.Pkg().Path()] + if pkg == nil { + return semanticcache.ProviderSetArtifact{}, false + } + art := oc.semanticArtifact(pkg) + if art == nil || !art.Supported { + return semanticcache.ProviderSetArtifact{}, false + } + setArt, ok := art.Vars[obj.Name()] + if !ok { + return semanticcache.ProviderSetArtifact{}, false + } + for _, item := range setArt.Items { + if item.Kind == "bind" { + return semanticcache.ProviderSetArtifact{}, false + } + } + return setArt, true +} + func (oc *objectCache) semanticProvider(importPath, name string) (*Provider, []error) { pkg := oc.packages[importPath] if pkg == nil || pkg.Types == nil { From 5f295650675a7ef27313893b599fa091b258de59 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:51:21 -0500 Subject: [PATCH 40/82] refactor: extract semantic provider set item application --- internal/wire/parse.go | 82 ++++++++++++++++++++++-------------------- 1 file changed, 44 insertions(+), 38 deletions(-) 
diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 80411ef..b988a2a 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -573,44 +573,8 @@ func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, } ec := new(errorCollector) for _, item := range setArt.Items { - switch item.Kind { - case "func": - providerObj, errs := oc.semanticProvider(item.ImportPath, item.Name) - if len(errs) > 0 { - ec.add(errs...) - continue - } - pset.Providers = append(pset.Providers, providerObj) - case "set": - setObj, errs := oc.semanticImportedSet(item.ImportPath, item.Name) - if len(errs) > 0 { - ec.add(errs...) - continue - } - pset.Imports = append(pset.Imports, setObj) - case "bind": - binding, errs := oc.semanticBinding(item) - if len(errs) > 0 { - ec.add(errs...) - continue - } - pset.Bindings = append(pset.Bindings, binding) - case "struct": - providerObj, errs := oc.semanticStructProvider(item) - if len(errs) > 0 { - ec.add(errs...) - continue - } - pset.Providers = append(pset.Providers, providerObj) - case "fields": - fields, errs := oc.semanticFields(item) - if len(errs) > 0 { - ec.add(errs...) - continue - } - pset.Fields = append(pset.Fields, fields...) - default: - ec.add(fmt.Errorf("unsupported semantic cache item kind %q", item.Kind)) + if errs := oc.applySemanticProviderSetItem(pset, item); len(errs) > 0 { + ec.add(errs...) 
} } if len(ec.errors) > 0 { @@ -648,6 +612,48 @@ func (oc *objectCache) semanticProviderSetArtifact(obj *types.Var) (semanticcach return setArt, true } +func (oc *objectCache) applySemanticProviderSetItem(pset *ProviderSet, item semanticcache.ProviderSetItemArtifact) []error { + switch item.Kind { + case "func": + providerObj, errs := oc.semanticProvider(item.ImportPath, item.Name) + if len(errs) > 0 { + return errs + } + pset.Providers = append(pset.Providers, providerObj) + return nil + case "set": + setObj, errs := oc.semanticImportedSet(item.ImportPath, item.Name) + if len(errs) > 0 { + return errs + } + pset.Imports = append(pset.Imports, setObj) + return nil + case "bind": + binding, errs := oc.semanticBinding(item) + if len(errs) > 0 { + return errs + } + pset.Bindings = append(pset.Bindings, binding) + return nil + case "struct": + providerObj, errs := oc.semanticStructProvider(item) + if len(errs) > 0 { + return errs + } + pset.Providers = append(pset.Providers, providerObj) + return nil + case "fields": + fields, errs := oc.semanticFields(item) + if len(errs) > 0 { + return errs + } + pset.Fields = append(pset.Fields, fields...) 
+ return nil + default: + return []error{fmt.Errorf("unsupported semantic cache item kind %q", item.Kind)} + } +} + func (oc *objectCache) semanticProvider(importPath, name string) (*Provider, []error) { pkg := oc.packages[importPath] if pkg == nil || pkg.Types == nil { From 91845b1a1462ec646fe56615268abc375414e6e0 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:52:32 -0500 Subject: [PATCH 41/82] refactor: share semantic struct field helpers --- internal/wire/parse.go | 69 +++++++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index b988a2a..04a3301 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -720,29 +720,11 @@ func (oc *objectCache) semanticStructProvider(item semanticcache.ProviderSetItem IsStruct: true, Out: []types.Type{out, types.NewPointer(out)}, } - if item.AllFields { - for i := 0; i < st.NumFields(); i++ { - if isPrevented(st.Tag(i)) { - continue - } - f := st.Field(i) - provider.Args = append(provider.Args, ProviderInput{ - Type: f.Type(), - FieldName: f.Name(), - }) - } - } else { - for _, fieldName := range item.FieldNames { - f := lookupStructField(st, fieldName) - if f == nil { - return nil, []error{fmt.Errorf("field %q not found in %s.%s", fieldName, item.Type.ImportPath, item.Type.Name)} - } - provider.Args = append(provider.Args, ProviderInput{ - Type: f.Type(), - FieldName: f.Name(), - }) - } + args, errs := semanticStructProviderInputs(st, item) + if len(errs) > 0 { + return nil, errs } + provider.Args = args return provider, nil } @@ -757,9 +739,9 @@ func (oc *objectCache) semanticFields(item semanticcache.ProviderSetItemArtifact } fields := make([]*Field, 0, len(item.FieldNames)) for _, fieldName := range item.FieldNames { - v := lookupStructField(structType, fieldName) - if v == nil { - return nil, []error{fmt.Errorf("field %q not found", fieldName)} + v, err := requiredStructField(structType, 
fieldName) + if err != nil { + return nil, []error{err} } out := []types.Type{v.Type()} if ptrToField { @@ -776,6 +758,35 @@ func (oc *objectCache) semanticFields(item semanticcache.ProviderSetItemArtifact return fields, nil } +func semanticStructProviderInputs(st *types.Struct, item semanticcache.ProviderSetItemArtifact) ([]ProviderInput, []error) { + if item.AllFields { + args := make([]ProviderInput, 0, st.NumFields()) + for i := 0; i < st.NumFields(); i++ { + if isPrevented(st.Tag(i)) { + continue + } + f := st.Field(i) + args = append(args, ProviderInput{ + Type: f.Type(), + FieldName: f.Name(), + }) + } + return args, nil + } + args := make([]ProviderInput, 0, len(item.FieldNames)) + for _, fieldName := range item.FieldNames { + f, err := requiredStructField(st, fieldName) + if err != nil { + return nil, []error{fmt.Errorf("field %q not found in %s.%s", fieldName, item.Type.ImportPath, item.Type.Name)} + } + args = append(args, ProviderInput{ + Type: f.Type(), + FieldName: f.Name(), + }) + } + return args, nil +} + func (oc *objectCache) semanticType(ref semanticcache.TypeRef) (types.Type, error) { typeName, err := oc.semanticTypeName(ref) if err != nil { @@ -829,6 +840,14 @@ func lookupStructField(st *types.Struct, name string) *types.Var { return nil } +func requiredStructField(st *types.Struct, name string) (*types.Var, error) { + v := lookupStructField(st, name) + if v == nil { + return nil, fmt.Errorf("field %q not found", name) + } + return v, nil +} + func (oc *objectCache) semanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { if pkg == nil { return nil From 9d894b29550e8bde205b1d16bc56ba11197ef6fc Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 18:57:42 -0500 Subject: [PATCH 42/82] refactor: share semantic package object lookup --- internal/wire/parse.go | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 
04a3301..36a6feb 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -655,11 +655,10 @@ func (oc *objectCache) applySemanticProviderSetItem(pset *ProviderSet, item sema } func (oc *objectCache) semanticProvider(importPath, name string) (*Provider, []error) { - pkg := oc.packages[importPath] - if pkg == nil || pkg.Types == nil { - return nil, []error{fmt.Errorf("missing typed package for %s", importPath)} + obj, err := oc.lookupPackageObject(importPath, name) + if err != nil { + return nil, []error{err} } - obj := pkg.Types.Scope().Lookup(name) fn, ok := obj.(*types.Func) if !ok || fn == nil { return nil, []error{fmt.Errorf("%s.%s is not a provider function", importPath, name)} @@ -668,11 +667,10 @@ func (oc *objectCache) semanticProvider(importPath, name string) (*Provider, []e } func (oc *objectCache) semanticImportedSet(importPath, name string) (*ProviderSet, []error) { - pkg := oc.packages[importPath] - if pkg == nil || pkg.Types == nil { - return nil, []error{fmt.Errorf("missing typed package for %s", importPath)} + obj, err := oc.lookupPackageObject(importPath, name) + if err != nil { + return nil, []error{err} } - obj := pkg.Types.Scope().Lookup(name) v, ok := obj.(*types.Var) if !ok || v == nil || !isProviderSetType(v.Type()) { return nil, []error{fmt.Errorf("%s.%s is not a provider set", importPath, name)} @@ -800,11 +798,10 @@ func (oc *objectCache) semanticType(ref semanticcache.TypeRef) (types.Type, erro } func (oc *objectCache) semanticTypeName(ref semanticcache.TypeRef) (*types.TypeName, error) { - pkg := oc.packages[ref.ImportPath] - if pkg == nil || pkg.Types == nil { - return nil, fmt.Errorf("missing typed package for %s", ref.ImportPath) + obj, err := oc.lookupPackageObject(ref.ImportPath, ref.Name) + if err != nil { + return nil, err } - obj := pkg.Types.Scope().Lookup(ref.Name) typeName, ok := obj.(*types.TypeName) if !ok || typeName == nil { return nil, fmt.Errorf("%s.%s is not a named type", ref.ImportPath, ref.Name) @@ -812,6 
+809,14 @@ func (oc *objectCache) semanticTypeName(ref semanticcache.TypeRef) (*types.TypeN return typeName, nil } +func (oc *objectCache) lookupPackageObject(importPath, name string) (types.Object, error) { + pkg := oc.packages[importPath] + if pkg == nil || pkg.Types == nil { + return nil, fmt.Errorf("missing typed package for %s", importPath) + } + return pkg.Types.Scope().Lookup(name), nil +} + func structFromFieldsParent(parent types.Type) (*types.Struct, bool, error) { ptr, ok := parent.(*types.Pointer) if !ok { From 3775c4cb7ac0e703145edca7762ed8a8ea4c0bbb Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:00:27 -0500 Subject: [PATCH 43/82] refactor: share semantic output type assembly --- internal/wire/parse.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 36a6feb..0fa211b 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -716,7 +716,7 @@ func (oc *objectCache) semanticStructProvider(item semanticcache.ProviderSetItem Name: typeName.Name(), Pos: typeName.Pos(), IsStruct: true, - Out: []types.Type{out, types.NewPointer(out)}, + Out: typeAndPointer(out), } args, errs := semanticStructProviderInputs(st, item) if len(errs) > 0 { @@ -741,16 +741,12 @@ func (oc *objectCache) semanticFields(item semanticcache.ProviderSetItemArtifact if err != nil { return nil, []error{err} } - out := []types.Type{v.Type()} - if ptrToField { - out = append(out, types.NewPointer(v.Type())) - } fields = append(fields, &Field{ Parent: parent, Name: v.Name(), Pkg: v.Pkg(), Pos: v.Pos(), - Out: out, + Out: fieldOutputTypes(v.Type(), ptrToField), }) } return fields, nil @@ -785,6 +781,18 @@ func semanticStructProviderInputs(st *types.Struct, item semanticcache.ProviderS return args, nil } +func typeAndPointer(typ types.Type) []types.Type { + return []types.Type{typ, types.NewPointer(typ)} +} + +func fieldOutputTypes(typ types.Type, includePointer bool) 
[]types.Type { + out := []types.Type{typ} + if includePointer { + out = append(out, types.NewPointer(typ)) + } + return out +} + func (oc *objectCache) semanticType(ref semanticcache.TypeRef) (types.Type, error) { typeName, err := oc.semanticTypeName(ref) if err != nil { From 2091c4b13f0678bf9e33c7144aebd90e2d80ae57 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:04:01 -0500 Subject: [PATCH 44/82] refactor: share struct provider shell assembly --- internal/wire/parse.go | 37 +++++++++++++++---------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 0fa211b..3a19ba4 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -711,13 +711,7 @@ func (oc *objectCache) semanticStructProvider(item semanticcache.ProviderSetItem if !ok { return nil, []error{fmt.Errorf("%s.%s does not name a struct", item.Type.ImportPath, item.Type.Name)} } - provider := &Provider{ - Pkg: typeName.Pkg(), - Name: typeName.Name(), - Pos: typeName.Pos(), - IsStruct: true, - Out: typeAndPointer(out), - } + provider := newStructProvider(typeName, typeAndPointer(out)) args, errs := semanticStructProviderInputs(st, item) if len(errs) > 0 { return nil, errs @@ -793,6 +787,16 @@ func fieldOutputTypes(typ types.Type, includePointer bool) []types.Type { return out } +func newStructProvider(typeName types.Object, out []types.Type) *Provider { + return &Provider{ + Pkg: typeName.Pkg(), + Name: typeName.Name(), + Pos: typeName.Pos(), + IsStruct: true, + Out: out, + } +} + func (oc *objectCache) semanticType(ref semanticcache.TypeRef) (types.Type, error) { typeName, err := oc.semanticTypeName(ref) if err != nil { @@ -1448,14 +1452,9 @@ func processStructLiteralProvider(fset *token.FileSet, typeName *types.TypeName) notePosition(fset.Position(pos), fmt.Errorf("using struct literal to inject %s is deprecated and will be removed in the next release; use wire.Struct instead", typeName.Type()))) - 
provider := &Provider{ - Pkg: typeName.Pkg(), - Name: typeName.Name(), - Pos: pos, - Args: make([]ProviderInput, st.NumFields()), - IsStruct: true, - Out: []types.Type{out, types.NewPointer(out)}, - } + provider := newStructProvider(typeName, typeAndPointer(out)) + provider.Pos = pos + provider.Args = make([]ProviderInput, st.NumFields()) for i := 0; i < st.NumFields(); i++ { f := st.Field(i) provider.Args[i] = ProviderInput{ @@ -1496,13 +1495,7 @@ func processStructProvider(fset *token.FileSet, info *types.Info, call *ast.Call stExpr := call.Args[0].(*ast.CallExpr) typeName := qualifiedIdentObject(info, stExpr.Args[0]) // should be either an identifier or selector - provider := &Provider{ - Pkg: typeName.Pkg(), - Name: typeName.Name(), - Pos: typeName.Pos(), - IsStruct: true, - Out: []types.Type{structPtr.Elem(), structPtr}, - } + provider := newStructProvider(typeName, []types.Type{structPtr.Elem(), structPtr}) if allFields(call) { for i := 0; i < st.NumFields(); i++ { if isPrevented(st.Tag(i)) { From a7cd4cde95d2787679cdbcd4e4a8d157eb693eef Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:08:45 -0500 Subject: [PATCH 45/82] refactor: share allowed struct field inputs --- internal/wire/parse.go | 52 ++++++++++++++++++------------------------ 1 file changed, 22 insertions(+), 30 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 3a19ba4..88657a7 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -748,18 +748,7 @@ func (oc *objectCache) semanticFields(item semanticcache.ProviderSetItemArtifact func semanticStructProviderInputs(st *types.Struct, item semanticcache.ProviderSetItemArtifact) ([]ProviderInput, []error) { if item.AllFields { - args := make([]ProviderInput, 0, st.NumFields()) - for i := 0; i < st.NumFields(); i++ { - if isPrevented(st.Tag(i)) { - continue - } - f := st.Field(i) - args = append(args, ProviderInput{ - Type: f.Type(), - FieldName: f.Name(), - }) - } - return args, nil + return 
providerInputsForAllowedStructFields(st), nil } args := make([]ProviderInput, 0, len(item.FieldNames)) for _, fieldName := range item.FieldNames { @@ -767,14 +756,29 @@ func semanticStructProviderInputs(st *types.Struct, item semanticcache.ProviderS if err != nil { return nil, []error{fmt.Errorf("field %q not found in %s.%s", fieldName, item.Type.ImportPath, item.Type.Name)} } - args = append(args, ProviderInput{ - Type: f.Type(), - FieldName: f.Name(), - }) + args = append(args, providerInputForVar(f)) } return args, nil } +func providerInputsForAllowedStructFields(st *types.Struct) []ProviderInput { + args := make([]ProviderInput, 0, st.NumFields()) + for i := 0; i < st.NumFields(); i++ { + if isPrevented(st.Tag(i)) { + continue + } + args = append(args, providerInputForVar(st.Field(i))) + } + return args +} + +func providerInputForVar(v *types.Var) ProviderInput { + return ProviderInput{ + Type: v.Type(), + FieldName: v.Name(), + } +} + func typeAndPointer(typ types.Type) []types.Type { return []types.Type{typ, types.NewPointer(typ)} } @@ -1497,16 +1501,7 @@ func processStructProvider(fset *token.FileSet, info *types.Info, call *ast.Call typeName := qualifiedIdentObject(info, stExpr.Args[0]) // should be either an identifier or selector provider := newStructProvider(typeName, []types.Type{structPtr.Elem(), structPtr}) if allFields(call) { - for i := 0; i < st.NumFields(); i++ { - if isPrevented(st.Tag(i)) { - continue - } - f := st.Field(i) - provider.Args = append(provider.Args, ProviderInput{ - Type: f.Type(), - FieldName: f.Name(), - }) - } + provider.Args = providerInputsForAllowedStructFields(st) } else { provider.Args = make([]ProviderInput, len(call.Args)-1) for i := 1; i < len(call.Args); i++ { @@ -1514,10 +1509,7 @@ func processStructProvider(fset *token.FileSet, info *types.Info, call *ast.Call if err != nil { return nil, notePosition(fset.Position(call.Pos()), err) } - provider.Args[i-1] = ProviderInput{ - Type: v.Type(), - FieldName: v.Name(), - } + 
provider.Args[i-1] = providerInputForVar(v) } } for i := 0; i < len(provider.Args); i++ { From 051fb75473a893cb77470f64137c7213e1f8e9ae Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:09:55 -0500 Subject: [PATCH 46/82] refactor: share selected struct field inputs --- internal/wire/parse.go | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 88657a7..9f9a26d 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -750,24 +750,32 @@ func semanticStructProviderInputs(st *types.Struct, item semanticcache.ProviderS if item.AllFields { return providerInputsForAllowedStructFields(st), nil } - args := make([]ProviderInput, 0, len(item.FieldNames)) + fields := make([]*types.Var, 0, len(item.FieldNames)) for _, fieldName := range item.FieldNames { f, err := requiredStructField(st, fieldName) if err != nil { return nil, []error{fmt.Errorf("field %q not found in %s.%s", fieldName, item.Type.ImportPath, item.Type.Name)} } - args = append(args, providerInputForVar(f)) + fields = append(fields, f) } - return args, nil + return providerInputsForVars(fields), nil } func providerInputsForAllowedStructFields(st *types.Struct) []ProviderInput { - args := make([]ProviderInput, 0, st.NumFields()) + fields := make([]*types.Var, 0, st.NumFields()) for i := 0; i < st.NumFields(); i++ { if isPrevented(st.Tag(i)) { continue } - args = append(args, providerInputForVar(st.Field(i))) + fields = append(fields, st.Field(i)) + } + return providerInputsForVars(fields) +} + +func providerInputsForVars(vars []*types.Var) []ProviderInput { + args := make([]ProviderInput, 0, len(vars)) + for _, v := range vars { + args = append(args, providerInputForVar(v)) } return args } @@ -1503,14 +1511,15 @@ func processStructProvider(fset *token.FileSet, info *types.Info, call *ast.Call if allFields(call) { provider.Args = providerInputsForAllowedStructFields(st) } else { - provider.Args = 
make([]ProviderInput, len(call.Args)-1) + fields := make([]*types.Var, 0, len(call.Args)-1) for i := 1; i < len(call.Args); i++ { v, err := checkField(call.Args[i], st) if err != nil { return nil, notePosition(fset.Position(call.Pos()), err) } - provider.Args[i-1] = providerInputForVar(v) + fields = append(fields, v) } + provider.Args = providerInputsForVars(fields) } for i := 0; i < len(provider.Args); i++ { for j := 0; j < i; j++ { From 7f6195a67c2f397ae839fc4debbb0d88434ca760 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:10:59 -0500 Subject: [PATCH 47/82] refactor: share field output assembly for FieldsOf --- internal/wire/parse.go | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 9f9a26d..3342aad 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -1713,18 +1713,12 @@ func processFieldsOf(fset *token.FileSet, info *types.Info, call *ast.CallExpr) if err != nil { return nil, notePosition(fset.Position(call.Pos()), err) } - out := []types.Type{v.Type()} - if isPtrToStruct { - // If the field is from a pointer to a struct, then - // wire.Fields also provides a pointer to the field. 
- out = append(out, types.NewPointer(v.Type())) - } fields = append(fields, &Field{ Parent: structPtr.Elem(), Name: v.Name(), Pkg: v.Pkg(), Pos: v.Pos(), - Out: out, + Out: fieldOutputTypes(v.Type(), isPtrToStruct), }) } return fields, nil From 326425316c696b346b1f8fea9c8024a9d433be48 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:11:56 -0500 Subject: [PATCH 48/82] refactor: share quoted struct field lookup --- internal/wire/parse.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 3342aad..6395d00 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -877,6 +877,15 @@ func requiredStructField(st *types.Struct, name string) (*types.Var, error) { return v, nil } +func lookupQuotedStructField(st *types.Struct, quotedName string) (*types.Var, int) { + for i := 0; i < st.NumFields(); i++ { + if strings.EqualFold(strconv.Quote(st.Field(i).Name()), quotedName) { + return st.Field(i), i + } + } + return nil, -1 +} + func (oc *objectCache) semanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { if pkg == nil { return nil @@ -1731,13 +1740,12 @@ func checkField(f ast.Expr, st *types.Struct) (*types.Var, error) { if !ok { return nil, fmt.Errorf("%v must be a string with the field name", f) } - for i := 0; i < st.NumFields(); i++ { - if strings.EqualFold(strconv.Quote(st.Field(i).Name()), b.Value) { - if isPrevented(st.Tag(i)) { - return nil, fmt.Errorf("%s is prevented from injecting by wire", b.Value) - } - return st.Field(i), nil + v, i := lookupQuotedStructField(st, b.Value) + if v != nil { + if isPrevented(st.Tag(i)) { + return nil, fmt.Errorf("%s is prevented from injecting by wire", b.Value) } + return v, nil } return nil, fmt.Errorf("%s is not a field of %s", b.Value, st.String()) } From 41f4d7dfecffec3be083097a737f206934e4a2ff Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:14:35 -0500 Subject: [PATCH 
49/82] refactor: share semantic pointer expansion --- internal/wire/parse.go | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 6395d00..2c36c76 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -788,13 +788,13 @@ func providerInputForVar(v *types.Var) ProviderInput { } func typeAndPointer(typ types.Type) []types.Type { - return []types.Type{typ, types.NewPointer(typ)} + return []types.Type{typ, applyTypePointers(typ, 1)} } func fieldOutputTypes(typ types.Type, includePointer bool) []types.Type { out := []types.Type{typ} if includePointer { - out = append(out, types.NewPointer(typ)) + out = append(out, applyTypePointers(typ, 1)) } return out } @@ -814,11 +814,7 @@ func (oc *objectCache) semanticType(ref semanticcache.TypeRef) (types.Type, erro if err != nil { return nil, err } - var typ types.Type = typeName.Type() - for i := 0; i < ref.Pointer; i++ { - typ = types.NewPointer(typ) - } - return typ, nil + return applyTypePointers(typeName.Type(), ref.Pointer), nil } func (oc *objectCache) semanticTypeName(ref semanticcache.TypeRef) (*types.TypeName, error) { @@ -841,6 +837,13 @@ func (oc *objectCache) lookupPackageObject(importPath, name string) (types.Objec return pkg.Types.Scope().Lookup(name), nil } +func applyTypePointers(typ types.Type, count int) types.Type { + for i := 0; i < count; i++ { + typ = types.NewPointer(typ) + } + return typ +} + func structFromFieldsParent(parent types.Type) (*types.Struct, bool, error) { ptr, ok := parent.(*types.Pointer) if !ok { From dea5a636ad6252f2862cfbd7167aacfc69fde04c Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:15:23 -0500 Subject: [PATCH 50/82] refactor: reuse field parent struct resolution --- internal/wire/parse.go | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 2c36c76..f52200f 100644 --- 
a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -1697,22 +1697,10 @@ func processFieldsOf(fset *token.FileSet, info *types.Info, call *ast.CallExpr) return nil, notePosition(fset.Position(call.Pos()), fmt.Errorf(firstArgReqFormat, types.TypeString(structType, nil))) } - - var struc *types.Struct - isPtrToStruct := false - switch t := structPtr.Elem().Underlying().(type) { - case *types.Pointer: - struc, ok = t.Elem().Underlying().(*types.Struct) - if !ok { - return nil, notePosition(fset.Position(call.Pos()), - fmt.Errorf(firstArgReqFormat, types.TypeString(struc, nil))) - } - isPtrToStruct = true - case *types.Struct: - struc = t - default: + struc, isPtrToStruct, err := structFromFieldsParent(structPtr) + if err != nil { return nil, notePosition(fset.Position(call.Pos()), - fmt.Errorf(firstArgReqFormat, types.TypeString(t, nil))) + fmt.Errorf(firstArgReqFormat, types.TypeString(structType, nil))) } if struc.NumFields() < len(call.Args)-1 { return nil, notePosition(fset.Position(call.Pos()), From 06060679318057694ce47598f9ec9656aa29360f Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:17:47 -0500 Subject: [PATCH 51/82] refactor: share field object assembly --- internal/wire/parse.go | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index f52200f..f1a43b1 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -735,13 +735,7 @@ func (oc *objectCache) semanticFields(item semanticcache.ProviderSetItemArtifact if err != nil { return nil, []error{err} } - fields = append(fields, &Field{ - Parent: parent, - Name: v.Name(), - Pkg: v.Pkg(), - Pos: v.Pos(), - Out: fieldOutputTypes(v.Type(), ptrToField), - }) + fields = append(fields, newField(parent, v, ptrToField)) } return fields, nil } @@ -787,6 +781,16 @@ func providerInputForVar(v *types.Var) ProviderInput { } } +func newField(parent types.Type, v *types.Var, includePointer bool) *Field 
{ + return &Field{ + Parent: parent, + Name: v.Name(), + Pkg: v.Pkg(), + Pos: v.Pos(), + Out: fieldOutputTypes(v.Type(), includePointer), + } +} + func typeAndPointer(typ types.Type) []types.Type { return []types.Type{typ, applyTypePointers(typ, 1)} } @@ -1713,13 +1717,7 @@ func processFieldsOf(fset *token.FileSet, info *types.Info, call *ast.CallExpr) if err != nil { return nil, notePosition(fset.Position(call.Pos()), err) } - fields = append(fields, &Field{ - Parent: structPtr.Elem(), - Name: v.Name(), - Pkg: v.Pkg(), - Pos: v.Pos(), - Out: fieldOutputTypes(v.Type(), isPtrToStruct), - }) + fields = append(fields, newField(structPtr.Elem(), v, isPtrToStruct)) } return fields, nil } From f9d735f21db84886fb1e4e7b1961b9d0c79a87dd Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:19:12 -0500 Subject: [PATCH 52/82] refactor: share named struct type resolution --- internal/wire/parse.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index f1a43b1..353d743 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -706,8 +706,7 @@ func (oc *objectCache) semanticStructProvider(item semanticcache.ProviderSetItem if err != nil { return nil, []error{err} } - out := typeName.Type() - st, ok := out.Underlying().(*types.Struct) + out, st, ok := namedStructType(typeName) if !ok { return nil, []error{fmt.Errorf("%s.%s does not name a struct", item.Type.ImportPath, item.Type.Name)} } @@ -848,6 +847,12 @@ func applyTypePointers(typ types.Type, count int) types.Type { return typ } +func namedStructType(typeName types.Object) (types.Type, *types.Struct, bool) { + out := typeName.Type() + st, ok := out.Underlying().(*types.Struct) + return out, st, ok +} + func structFromFieldsParent(parent types.Type) (*types.Struct, bool, error) { ptr, ok := parent.(*types.Pointer) if !ok { @@ -1468,8 +1473,7 @@ func funcOutput(sig *types.Signature) (outputSignature, error) { // It will not 
support any new feature introduced after v0.2. Please use the new // wire.Struct syntax for those. func processStructLiteralProvider(fset *token.FileSet, typeName *types.TypeName) (*Provider, []error) { - out := typeName.Type() - st, ok := out.Underlying().(*types.Struct) + out, st, ok := namedStructType(typeName) if !ok { return nil, []error{fmt.Errorf("%v does not name a struct", typeName)} } From 4477c75c6ed812d59384b70678f57b9fb1f93932 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:20:16 -0500 Subject: [PATCH 53/82] refactor: share semantic type name lookup --- internal/wire/parse.go | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 353d743..4220d32 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -821,15 +821,7 @@ func (oc *objectCache) semanticType(ref semanticcache.TypeRef) (types.Type, erro } func (oc *objectCache) semanticTypeName(ref semanticcache.TypeRef) (*types.TypeName, error) { - obj, err := oc.lookupPackageObject(ref.ImportPath, ref.Name) - if err != nil { - return nil, err - } - typeName, ok := obj.(*types.TypeName) - if !ok || typeName == nil { - return nil, fmt.Errorf("%s.%s is not a named type", ref.ImportPath, ref.Name) - } - return typeName, nil + return oc.lookupPackageTypeName(ref.ImportPath, ref.Name) } func (oc *objectCache) lookupPackageObject(importPath, name string) (types.Object, error) { @@ -840,6 +832,18 @@ func (oc *objectCache) lookupPackageObject(importPath, name string) (types.Objec return pkg.Types.Scope().Lookup(name), nil } +func (oc *objectCache) lookupPackageTypeName(importPath, name string) (*types.TypeName, error) { + obj, err := oc.lookupPackageObject(importPath, name) + if err != nil { + return nil, err + } + typeName, ok := obj.(*types.TypeName) + if !ok || typeName == nil { + return nil, fmt.Errorf("%s.%s is not a named type", importPath, name) + } + return typeName, nil +} + func 
applyTypePointers(typ types.Type, count int) types.Type { for i := 0; i < count; i++ { typ = types.NewPointer(typ) From 2c0446de652895b389e96a14349b926b223adaed Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:21:11 -0500 Subject: [PATCH 54/82] refactor: share semantic package member lookup --- internal/wire/parse.go | 36 ++++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 4220d32..0677cd0 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -655,26 +655,18 @@ func (oc *objectCache) applySemanticProviderSetItem(pset *ProviderSet, item sema } func (oc *objectCache) semanticProvider(importPath, name string) (*Provider, []error) { - obj, err := oc.lookupPackageObject(importPath, name) + fn, err := oc.lookupPackageFunc(importPath, name) if err != nil { return nil, []error{err} } - fn, ok := obj.(*types.Func) - if !ok || fn == nil { - return nil, []error{fmt.Errorf("%s.%s is not a provider function", importPath, name)} - } return processFuncProvider(oc.fset, fn) } func (oc *objectCache) semanticImportedSet(importPath, name string) (*ProviderSet, []error) { - obj, err := oc.lookupPackageObject(importPath, name) + v, err := oc.lookupProviderSetVar(importPath, name) if err != nil { return nil, []error{err} } - v, ok := obj.(*types.Var) - if !ok || v == nil || !isProviderSetType(v.Type()) { - return nil, []error{fmt.Errorf("%s.%s is not a provider set", importPath, name)} - } item, errs := oc.get(v) if len(errs) > 0 { return nil, errs @@ -844,6 +836,30 @@ func (oc *objectCache) lookupPackageTypeName(importPath, name string) (*types.Ty return typeName, nil } +func (oc *objectCache) lookupPackageFunc(importPath, name string) (*types.Func, error) { + obj, err := oc.lookupPackageObject(importPath, name) + if err != nil { + return nil, err + } + fn, ok := obj.(*types.Func) + if !ok || fn == nil { + return nil, fmt.Errorf("%s.%s is not a 
provider function", importPath, name) + } + return fn, nil +} + +func (oc *objectCache) lookupProviderSetVar(importPath, name string) (*types.Var, error) { + obj, err := oc.lookupPackageObject(importPath, name) + if err != nil { + return nil, err + } + v, ok := obj.(*types.Var) + if !ok || v == nil || !isProviderSetType(v.Type()) { + return nil, fmt.Errorf("%s.%s is not a provider set", importPath, name) + } + return v, nil +} + func applyTypePointers(typ types.Type, count int) types.Type { for i := 0; i < count; i++ { typ = types.NewPointer(typ) From 848371ff0edfc0a4ba9f06a9423d809e314ad648 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:22:09 -0500 Subject: [PATCH 55/82] refactor: share semantic error wrapping --- internal/wire/parse.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 0677cd0..430cb4f 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -681,11 +681,11 @@ func (oc *objectCache) semanticImportedSet(importPath, name string) (*ProviderSe func (oc *objectCache) semanticBinding(item semanticcache.ProviderSetItemArtifact) (*IfaceBinding, []error) { iface, err := oc.semanticType(item.Type) if err != nil { - return nil, []error{err} + return nil, semanticErrors(err) } provided, err := oc.semanticType(item.Type2) if err != nil { - return nil, []error{err} + return nil, semanticErrors(err) } return &IfaceBinding{ Iface: iface, @@ -696,11 +696,11 @@ func (oc *objectCache) semanticBinding(item semanticcache.ProviderSetItemArtifac func (oc *objectCache) semanticStructProvider(item semanticcache.ProviderSetItemArtifact) (*Provider, []error) { typeName, err := oc.semanticTypeName(item.Type) if err != nil { - return nil, []error{err} + return nil, semanticErrors(err) } out, st, ok := namedStructType(typeName) if !ok { - return nil, []error{fmt.Errorf("%s.%s does not name a struct", item.Type.ImportPath, item.Type.Name)} + return nil, 
semanticErrors(fmt.Errorf("%s.%s does not name a struct", item.Type.ImportPath, item.Type.Name)) } provider := newStructProvider(typeName, typeAndPointer(out)) args, errs := semanticStructProviderInputs(st, item) @@ -714,17 +714,17 @@ func (oc *objectCache) semanticStructProvider(item semanticcache.ProviderSetItem func (oc *objectCache) semanticFields(item semanticcache.ProviderSetItemArtifact) ([]*Field, []error) { parent, err := oc.semanticType(item.Type) if err != nil { - return nil, []error{err} + return nil, semanticErrors(err) } structType, ptrToField, err := structFromFieldsParent(parent) if err != nil { - return nil, []error{err} + return nil, semanticErrors(err) } fields := make([]*Field, 0, len(item.FieldNames)) for _, fieldName := range item.FieldNames { v, err := requiredStructField(structType, fieldName) if err != nil { - return nil, []error{err} + return nil, semanticErrors(err) } fields = append(fields, newField(parent, v, ptrToField)) } @@ -765,6 +765,10 @@ func providerInputsForVars(vars []*types.Var) []ProviderInput { return args } +func semanticErrors(err error) []error { + return []error{err} +} + func providerInputForVar(v *types.Var) ProviderInput { return ProviderInput{ Type: v.Type(), From 3944cbb876cca053e5a8b184389a1b71a14a0228 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:23:07 -0500 Subject: [PATCH 56/82] refactor: share provider set finalization --- internal/wire/parse.go | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 430cb4f..834f8c2 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -580,12 +580,7 @@ func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, if len(ec.errors) > 0 { return nil, true, ec.errors } - var errs []error - pset.providerMap, pset.srcMap, errs = buildProviderMap(oc.fset, oc.hasher, pset) - if len(errs) > 0 { - return nil, true, errs - } - if errs := 
verifyAcyclic(pset.providerMap, oc.hasher); len(errs) > 0 { + if errs := oc.finalizeProviderSet(pset); len(errs) > 0 { return nil, true, errs } return pset, true, nil @@ -1361,15 +1356,22 @@ func (oc *objectCache) processNewSet(info *types.Info, pkgPath string, call *ast if len(ec.errors) > 0 { return nil, ec.errors } + if errs := oc.finalizeProviderSet(pset); len(errs) > 0 { + return nil, errs + } + return pset, nil +} + +func (oc *objectCache) finalizeProviderSet(pset *ProviderSet) []error { var errs []error pset.providerMap, pset.srcMap, errs = buildProviderMap(oc.fset, oc.hasher, pset) if len(errs) > 0 { - return nil, errs + return errs } if errs := verifyAcyclic(pset.providerMap, oc.hasher); len(errs) > 0 { - return nil, errs + return errs } - return pset, nil + return nil } // structArgType attempts to interpret an expression as a simple struct type. From 12858a2ff65ee60de69cf55e8c1f531e87e58810 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 19:42:51 -0500 Subject: [PATCH 57/82] refactor: add isolated output cache gate --- internal/wire/output_cache.go | 8 ++++++- internal/wire/output_cache_test.go | 38 ++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 internal/wire/output_cache_test.go diff --git a/internal/wire/output_cache.go b/internal/wire/output_cache.go index 35eacfe..b95a514 100644 --- a/internal/wire/output_cache.go +++ b/internal/wire/output_cache.go @@ -17,7 +17,10 @@ import ( "github.com/goforj/wire/internal/loader" ) -const outputCacheDirEnv = "WIRE_OUTPUT_CACHE_DIR" +const ( + outputCacheDirEnv = "WIRE_OUTPUT_CACHE_DIR" + outputCacheEnabledEnv = "WIRE_OUTPUT_CACHE" +) type outputCacheEntry struct { Version int @@ -104,6 +107,9 @@ func outputCacheEnabled(ctx context.Context, wd string, env []string) bool { if effectiveLoaderMode(ctx, wd, env) == loader.ModeFallback { return false } + if envValue(env, outputCacheEnabledEnv) == "0" { + return false + } return envValue(env, 
"WIRE_LOADER_ARTIFACTS") != "0" } diff --git a/internal/wire/output_cache_test.go b/internal/wire/output_cache_test.go new file mode 100644 index 0000000..a74621b --- /dev/null +++ b/internal/wire/output_cache_test.go @@ -0,0 +1,38 @@ +package wire + +import ( + "context" + "testing" +) + +func TestOutputCacheEnabled(t *testing.T) { + ctx := context.Background() + tests := []struct { + name string + env []string + want bool + }{ + { + name: "enabled with artifacts", + env: []string{"WIRE_LOADER_ARTIFACTS=1"}, + want: true, + }, + { + name: "disabled without artifacts", + env: []string{"WIRE_LOADER_ARTIFACTS=0"}, + want: false, + }, + { + name: "disabled by dedicated env", + env: []string{"WIRE_LOADER_ARTIFACTS=1", "WIRE_OUTPUT_CACHE=0"}, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := outputCacheEnabled(ctx, t.TempDir(), tt.env); got != tt.want { + t.Fatalf("outputCacheEnabled(..., %v) = %v, want %v", tt.env, got, tt.want) + } + }) + } +} From 8dbd59013ae0091cc8f1f7d272884beaddc73c1c Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 20:10:55 -0500 Subject: [PATCH 58/82] refactor: make provider set fallback policy explicit --- internal/wire/parse.go | 43 +++++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 834f8c2..cd20bd5 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -538,8 +538,8 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { switch obj := obj.(type) { case *types.Var: spec := oc.varDecl(obj) - if spec == nil && isProviderSetType(obj.Type()) { - if pset, ok, errs := oc.semanticProviderSet(obj); ok { + if isProviderSetType(obj.Type()) { + if pset, ok, errs := oc.providerSetForVar(obj, spec); ok { return pset, errs } } @@ -561,6 +561,13 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { } } +func (oc 
*objectCache) providerSetForVar(obj *types.Var, spec *ast.ValueSpec) (*ProviderSet, bool, []error) { + if spec != nil { + return nil, false, nil + } + return oc.semanticProviderSet(obj) +} + func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, []error) { setArt, ok := oc.semanticProviderSetArtifact(obj) if !ok { @@ -1213,22 +1220,10 @@ func summarizeTypeRef(typ types.Type) (semanticcache.TypeRef, bool) { } func semanticVarDecl(pkg *packages.Package, obj *types.Var) *ast.ValueSpec { - pos := obj.Pos() - for _, f := range pkg.Syntax { - tokenFile := pkg.Fset.File(f.Pos()) - if tokenFile == nil { - continue - } - if base := tokenFile.Base(); base <= int(pos) && int(pos) < base+tokenFile.Size() { - path, _ := astutil.PathEnclosingInterval(f, pos, pos) - for _, node := range path { - if spec, ok := node.(*ast.ValueSpec); ok { - return spec - } - } - } + if pkg == nil { + return nil } - return nil + return valueSpecForVar(pkg.Fset, pkg.Syntax, obj) } // varDecl finds the declaration that defines the given variable. @@ -1236,9 +1231,19 @@ func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec { // TODO(light): Walk files to build object -> declaration mapping, if more performant. 
// Recommended by https://golang.org/s/types-tutorial pkg := oc.packages[obj.Pkg().Path()] + if pkg == nil { + return nil + } + return valueSpecForVar(oc.fset, pkg.Syntax, obj) +} + +func valueSpecForVar(fset *token.FileSet, files []*ast.File, obj *types.Var) *ast.ValueSpec { pos := obj.Pos() - for _, f := range pkg.Syntax { - tokenFile := oc.fset.File(f.Pos()) + for _, f := range files { + tokenFile := fset.File(f.Pos()) + if tokenFile == nil { + continue + } if base := tokenFile.Base(); base <= int(pos) && int(pos) < base+tokenFile.Size() { path, _ := astutil.PathEnclosingInterval(f, pos, pos) for _, node := range path { From 9ef6b11db20058188d10f75fb4cddb214ac6c438 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 20:11:53 -0500 Subject: [PATCH 59/82] refactor: share custom loader root loading path --- internal/loader/custom.go | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 3529350..634545d 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -259,18 +259,13 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz fset = token.NewFileSet() } l := newCustomTypedGraphLoader(ctx, req.WD, req.Env, fset, meta, map[string]struct{}{req.Package: {}}, req.ParseFile, discoveryDuration) - prefetchStart := time.Now() - l.prefetchArtifacts() - l.stats.artifactPrefetch = time.Since(prefetchStart) - rootLoadStart := time.Now() - root, err := l.loadPackage(req.Package) + roots, err := loadCustomRootPackages(l, []string{req.Package}) if err != nil { return nil, err } - l.stats.rootLoad = time.Since(rootLoadStart) logTypedLoadStats(ctx, "lazy", l.stats) return &LazyLoadResult{ - Packages: []*packages.Package{root}, + Packages: roots, Backend: ModeCustom, }, nil } @@ -303,13 +298,26 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo return nil, unsupportedError{reason: 
"no root packages from metadata"} } l := newCustomTypedGraphLoader(ctx, req.WD, req.Env, fset, meta, targets, req.ParseFile, discoveryDuration) + rootPaths := nonDepRootImportPaths(meta) + roots, err := loadCustomRootPackages(l, rootPaths) + if err != nil { + return nil, err + } + logTypedLoadStats(ctx, "typed", l.stats) + return &PackageLoadResult{ + Packages: roots, + Backend: ModeCustom, + }, nil +} + +func loadCustomRootPackages(l *customTypedGraphLoader, paths []string) ([]*packages.Package, error) { prefetchStart := time.Now() l.prefetchArtifacts() l.stats.artifactPrefetch = time.Since(prefetchStart) + rootLoadStart := time.Now() - rootPaths := nonDepRootImportPaths(meta) - roots := make([]*packages.Package, 0, len(rootPaths)) - for _, path := range rootPaths { + roots := make([]*packages.Package, 0, len(paths)) + for _, path := range paths { root, err := l.loadPackage(path) if err != nil { return nil, err @@ -318,11 +326,7 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo } l.stats.rootLoad = time.Since(rootLoadStart) sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) - logTypedLoadStats(ctx, "typed", l.stats) - return &PackageLoadResult{ - Packages: roots, - Backend: ModeCustom, - }, nil + return roots, nil } func (v *customValidator) validatePackage(path string) (*packages.Package, error) { From 4386089cdff454a9d32b4ae131bb6b2f2affc8d9 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 20:12:51 -0500 Subject: [PATCH 60/82] refactor: share custom loader metadata root graph --- internal/loader/custom.go | 67 ++++++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 634545d..669e3f1 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -191,31 +191,11 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes if len(meta) == 0 { 
return nil, unsupportedError{reason: "empty go list result"} } - pkgs := make(map[string]*packages.Package, len(meta)) - for path, m := range meta { - pkgs[path] = packageStub(nil, m) - appendPackageMetaError(pkgs[path], m) - } - for path, m := range meta { - pkg := pkgs[path] - for _, imp := range m.Imports { - target := resolvedImportTarget(m, imp) - if dep := pkgs[target]; dep != nil { - pkg.Imports[imp] = dep - } - } - } - rootPaths := nonDepRootImportPaths(meta) - roots := make([]*packages.Package, 0, len(rootPaths)) - for _, path := range rootPaths { - if pkg := pkgs[path]; pkg != nil { - roots = append(roots, pkg) - } - } + pkgs := packageStubGraphFromMeta(nil, meta) + roots := rootPackagesFromMeta(meta, pkgs) if len(roots) == 0 { return nil, unsupportedError{reason: "no root packages from metadata"} } - sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) return &RootLoadResult{ Packages: roots, Backend: ModeCustom, @@ -290,10 +270,7 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo if fset == nil { fset = token.NewFileSet() } - targets := make(map[string]struct{}) - for _, path := range nonDepRootImportPaths(meta) { - targets[path] = struct{}{} - } + targets := rootTargetSet(meta) if len(targets) == 0 { return nil, unsupportedError{reason: "no root packages from metadata"} } @@ -1220,6 +1197,24 @@ func packageStub(fset *token.FileSet, meta *packageMeta) *packages.Package { } } +func packageStubGraphFromMeta(fset *token.FileSet, meta map[string]*packageMeta) map[string]*packages.Package { + pkgs := make(map[string]*packages.Package, len(meta)) + for path, m := range meta { + pkgs[path] = packageStub(fset, m) + appendPackageMetaError(pkgs[path], m) + } + for path, m := range meta { + pkg := pkgs[path] + for _, imp := range m.Imports { + target := resolvedImportTarget(m, imp) + if dep := pkgs[target]; dep != nil { + pkg.Imports[imp] = dep + } + } + } + return pkgs +} + func 
appendPackageMetaError(pkg *packages.Package, meta *packageMeta) bool { if pkg == nil || meta == nil || meta.Error == nil || strings.TrimSpace(meta.Error.Err) == "" { return false @@ -1288,6 +1283,26 @@ func nonDepRootImportPaths(meta map[string]*packageMeta) []string { return paths } +func rootTargetSet(meta map[string]*packageMeta) map[string]struct{} { + targets := make(map[string]struct{}) + for _, path := range nonDepRootImportPaths(meta) { + targets[path] = struct{}{} + } + return targets +} + +func rootPackagesFromMeta(meta map[string]*packageMeta, pkgs map[string]*packages.Package) []*packages.Package { + rootPaths := nonDepRootImportPaths(meta) + roots := make([]*packages.Package, 0, len(rootPaths)) + for _, path := range rootPaths { + if pkg := pkgs[path]; pkg != nil { + roots = append(roots, pkg) + } + } + sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) + return roots +} + func logTypedLoadStats(ctx context.Context, mode string, stats typedLoadStats) { prefix := "loader.custom." 
+ mode logDuration(ctx, prefix+".read_files.cumulative", stats.read) From b099db8f1a2d65f3a24b3484872c79c9376698fe Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 20:13:22 -0500 Subject: [PATCH 61/82] refactor: isolate semantic provider set support rule --- internal/wire/parse.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index cd20bd5..f56d2c1 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -606,12 +606,19 @@ func (oc *objectCache) semanticProviderSetArtifact(obj *types.Var) (semanticcach if !ok { return semanticcache.ProviderSetArtifact{}, false } + if !semanticProviderSetArtifactSupported(setArt) { + return semanticcache.ProviderSetArtifact{}, false + } + return setArt, true +} + +func semanticProviderSetArtifactSupported(setArt semanticcache.ProviderSetArtifact) bool { for _, item := range setArt.Items { if item.Kind == "bind" { - return semanticcache.ProviderSetArtifact{}, false + return false } } - return setArt, true + return true } func (oc *objectCache) applySemanticProviderSetItem(pset *ProviderSet, item semanticcache.ProviderSetItemArtifact) []error { From cf4bb809a4fa813f1108c372ede3cc46685137ce Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 20:26:52 -0500 Subject: [PATCH 62/82] refactor: fold back weak cleanup abstractions --- internal/loader/custom.go | 36 +++++++++++++----------------------- internal/wire/parse.go | 10 ++-------- 2 files changed, 15 insertions(+), 31 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 669e3f1..007eebd 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -192,10 +192,17 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes return nil, unsupportedError{reason: "empty go list result"} } pkgs := packageStubGraphFromMeta(nil, meta) - roots := rootPackagesFromMeta(meta, pkgs) + rootPaths := 
nonDepRootImportPaths(meta) + roots := make([]*packages.Package, 0, len(rootPaths)) + for _, path := range rootPaths { + if pkg := pkgs[path]; pkg != nil { + roots = append(roots, pkg) + } + } if len(roots) == 0 { return nil, unsupportedError{reason: "no root packages from metadata"} } + sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) return &RootLoadResult{ Packages: roots, Backend: ModeCustom, @@ -270,12 +277,15 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo if fset == nil { fset = token.NewFileSet() } - targets := rootTargetSet(meta) + rootPaths := nonDepRootImportPaths(meta) + targets := make(map[string]struct{}, len(rootPaths)) + for _, path := range rootPaths { + targets[path] = struct{}{} + } if len(targets) == 0 { return nil, unsupportedError{reason: "no root packages from metadata"} } l := newCustomTypedGraphLoader(ctx, req.WD, req.Env, fset, meta, targets, req.ParseFile, discoveryDuration) - rootPaths := nonDepRootImportPaths(meta) roots, err := loadCustomRootPackages(l, rootPaths) if err != nil { return nil, err @@ -1283,26 +1293,6 @@ func nonDepRootImportPaths(meta map[string]*packageMeta) []string { return paths } -func rootTargetSet(meta map[string]*packageMeta) map[string]struct{} { - targets := make(map[string]struct{}) - for _, path := range nonDepRootImportPaths(meta) { - targets[path] = struct{}{} - } - return targets -} - -func rootPackagesFromMeta(meta map[string]*packageMeta, pkgs map[string]*packages.Package) []*packages.Package { - rootPaths := nonDepRootImportPaths(meta) - roots := make([]*packages.Package, 0, len(rootPaths)) - for _, path := range rootPaths { - if pkg := pkgs[path]; pkg != nil { - roots = append(roots, pkg) - } - } - sort.Slice(roots, func(i, j int) bool { return roots[i].PkgPath < roots[j].PkgPath }) - return roots -} - func logTypedLoadStats(ctx context.Context, mode string, stats typedLoadStats) { prefix := "loader.custom." 
+ mode logDuration(ctx, prefix+".read_files.cumulative", stats.read) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index f56d2c1..66d51da 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -539,7 +539,8 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { case *types.Var: spec := oc.varDecl(obj) if isProviderSetType(obj.Type()) { - if pset, ok, errs := oc.providerSetForVar(obj, spec); ok { + if spec == nil { + pset, _, errs := oc.semanticProviderSet(obj) return pset, errs } } @@ -561,13 +562,6 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { } } -func (oc *objectCache) providerSetForVar(obj *types.Var, spec *ast.ValueSpec) (*ProviderSet, bool, []error) { - if spec != nil { - return nil, false, nil - } - return oc.semanticProviderSet(obj) -} - func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, []error) { setArt, ok := oc.semanticProviderSetArtifact(obj) if !ok { From c323a0f7fdc379cc1330d8e00ccc4ffc56b09914 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 20:32:44 -0500 Subject: [PATCH 63/82] refactor: narrow loader semantic artifact coupling --- internal/loader/artifact_cache.go | 8 ++++++-- internal/loader/custom.go | 16 ++++++++++++++++ internal/loader/loader_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/internal/loader/artifact_cache.go b/internal/loader/artifact_cache.go index e920d5a..f57ba76 100644 --- a/internal/loader/artifact_cache.go +++ b/internal/loader/artifact_cache.go @@ -153,9 +153,13 @@ func isProviderSetTypeForLoader(t types.Type) bool { if obj == nil || obj.Pkg() == nil { return false } - switch obj.Pkg().Path() { + return isWireImportPath(obj.Pkg().Path()) && obj.Name() == "ProviderSet" +} + +func isWireImportPath(path string) bool { + switch path { case "github.com/goforj/wire", "github.com/google/wire": - return obj.Name() == "ProviderSet" + return 
true default: return false } diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 007eebd..0cdd919 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -546,6 +546,18 @@ func (l *customTypedGraphLoader) localSemanticArtifactSupported(meta *packageMet return art.Supported } +func localPackageNeedsSemanticArtifacts(meta *packageMeta) bool { + if meta == nil { + return false + } + for _, path := range meta.Imports { + if isWireImportPath(path) { + return true + } + } + return false +} + func (l *customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isLocal bool) artifactPolicy { if !loaderArtifactEnabled(l.env) || isTarget { return artifactPolicy{} @@ -555,6 +567,10 @@ func (l *customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isL policy.read = true return policy } + if !localPackageNeedsSemanticArtifacts(meta) { + policy.read = true + return policy + } policy.read = l.localSemanticArtifactSupported(meta) return policy } diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 05cfaa7..8e5e585 100644 --- a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -2386,6 +2386,34 @@ func TestLoaderArtifactKeyExternalWithoutExportChangesWhenSourceChanges(t *testi } } +func TestArtifactPolicyLocalReadOnlyNeedsSemanticForWirePackages(t *testing.T) { + t.Parallel() + + loader := &customTypedGraphLoader{ + env: []string{"WIRE_LOADER_ARTIFACTS=1"}, + localSemanticOK: map[string]bool{"example.com/app": false}, + } + + nonWireMeta := &packageMeta{ + ImportPath: "example.com/app", + Imports: []string{"fmt", "example.com/dep"}, + } + wireMeta := &packageMeta{ + ImportPath: "example.com/app", + Imports: []string{"github.com/goforj/wire"}, + } + + if got := loader.artifactPolicy(nonWireMeta, false, true); !got.read || !got.write { + t.Fatalf("artifactPolicy(non-wire local) = %+v, want read+write", got) + } + if got := loader.artifactPolicy(wireMeta, false, true); 
got.read || !got.write { + t.Fatalf("artifactPolicy(wire local without semantic support) = %+v, want write-only", got) + } + if got := loader.artifactPolicy(wireMeta, false, false); !got.read || !got.write { + t.Fatalf("artifactPolicy(wire external) = %+v, want read+write", got) + } +} + func TestRunGoListIncludesExportDataForReplacedModule(t *testing.T) { root := t.TempDir() depRoot := filepath.Join(root, "depmod") From 7bf31e8accdeefa14f1fc1b4269c80472e21bebc Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 20:39:52 -0500 Subject: [PATCH 64/82] refactor: unify semantic provider set support rules --- internal/wire/parse.go | 3 +++ internal/wire/parse_coverage_test.go | 19 +++++++++++++++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 66d51da..2c33d99 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -1043,6 +1043,9 @@ func summarizeSemanticProviderSet(info *types.Info, expr ast.Expr, pkgPath strin } setArt.Items = append(setArt.Items, items...) 
} + if !semanticProviderSetArtifactSupported(setArt) { + return semanticcache.ProviderSetArtifact{}, false + } return setArt, true } diff --git a/internal/wire/parse_coverage_test.go b/internal/wire/parse_coverage_test.go index c3c4d8e..8b7e68a 100644 --- a/internal/wire/parse_coverage_test.go +++ b/internal/wire/parse_coverage_test.go @@ -299,14 +299,25 @@ func TestSummarizeSemanticProviderSetTypeOnlyForms(t *testing.T) { &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: fieldsIdent}, Args: []ast.Expr{ptrToPtrFooCall, &ast.BasicLit{Kind: token.STRING, Value: "\"Message\""}}}, }, } + if got, ok := summarizeSemanticProviderSet(info, call, "example.com/app"); ok || len(got.Items) != 0 { + t.Fatalf("summarizeSemanticProviderSet(bind case) = (%+v, %v), want unsupported", got, ok) + } + + call = &ast.CallExpr{ + Fun: &ast.SelectorExpr{X: wireIdent, Sel: newSetIdent}, + Args: []ast.Expr{ + &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: structIdent}, Args: []ast.Expr{newFooCall, &ast.BasicLit{Kind: token.STRING, Value: "\"*\""}}}, + &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: fieldsIdent}, Args: []ast.Expr{ptrToPtrFooCall, &ast.BasicLit{Kind: token.STRING, Value: "\"Message\""}}}, + }, + } got, ok := summarizeSemanticProviderSet(info, call, "example.com/app") if !ok { - t.Fatal("summarizeSemanticProviderSet() = unsupported, want supported") + t.Fatal("summarizeSemanticProviderSet(non-bind type-only forms) = unsupported, want supported") } - if len(got.Items) != 3 { - t.Fatalf("items len = %d, want 3", len(got.Items)) + if len(got.Items) != 2 { + t.Fatalf("items len = %d, want 2", len(got.Items)) } - if got.Items[0].Kind != "bind" || got.Items[1].Kind != "struct" || got.Items[2].Kind != "fields" { + if got.Items[0].Kind != "struct" || got.Items[1].Kind != "fields" { t.Fatalf("unexpected kinds: %+v", got.Items) } } From a7fc49c9993dc7fd2db747895f7a92c737f89fff Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 20:51:06 -0500 
Subject: [PATCH 65/82] fix: restore local loader artifact safety gate --- internal/loader/artifact_cache.go | 8 ++------ internal/loader/custom.go | 16 ---------------- internal/loader/loader_test.go | 28 ---------------------------- 3 files changed, 2 insertions(+), 50 deletions(-) diff --git a/internal/loader/artifact_cache.go b/internal/loader/artifact_cache.go index f57ba76..e920d5a 100644 --- a/internal/loader/artifact_cache.go +++ b/internal/loader/artifact_cache.go @@ -153,13 +153,9 @@ func isProviderSetTypeForLoader(t types.Type) bool { if obj == nil || obj.Pkg() == nil { return false } - return isWireImportPath(obj.Pkg().Path()) && obj.Name() == "ProviderSet" -} - -func isWireImportPath(path string) bool { - switch path { + switch obj.Pkg().Path() { case "github.com/goforj/wire", "github.com/google/wire": - return true + return obj.Name() == "ProviderSet" default: return false } diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 0cdd919..007eebd 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -546,18 +546,6 @@ func (l *customTypedGraphLoader) localSemanticArtifactSupported(meta *packageMet return art.Supported } -func localPackageNeedsSemanticArtifacts(meta *packageMeta) bool { - if meta == nil { - return false - } - for _, path := range meta.Imports { - if isWireImportPath(path) { - return true - } - } - return false -} - func (l *customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isLocal bool) artifactPolicy { if !loaderArtifactEnabled(l.env) || isTarget { return artifactPolicy{} @@ -567,10 +555,6 @@ func (l *customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isL policy.read = true return policy } - if !localPackageNeedsSemanticArtifacts(meta) { - policy.read = true - return policy - } policy.read = l.localSemanticArtifactSupported(meta) return policy } diff --git a/internal/loader/loader_test.go b/internal/loader/loader_test.go index 8e5e585..05cfaa7 100644 --- 
a/internal/loader/loader_test.go +++ b/internal/loader/loader_test.go @@ -2386,34 +2386,6 @@ func TestLoaderArtifactKeyExternalWithoutExportChangesWhenSourceChanges(t *testi } } -func TestArtifactPolicyLocalReadOnlyNeedsSemanticForWirePackages(t *testing.T) { - t.Parallel() - - loader := &customTypedGraphLoader{ - env: []string{"WIRE_LOADER_ARTIFACTS=1"}, - localSemanticOK: map[string]bool{"example.com/app": false}, - } - - nonWireMeta := &packageMeta{ - ImportPath: "example.com/app", - Imports: []string{"fmt", "example.com/dep"}, - } - wireMeta := &packageMeta{ - ImportPath: "example.com/app", - Imports: []string{"github.com/goforj/wire"}, - } - - if got := loader.artifactPolicy(nonWireMeta, false, true); !got.read || !got.write { - t.Fatalf("artifactPolicy(non-wire local) = %+v, want read+write", got) - } - if got := loader.artifactPolicy(wireMeta, false, true); got.read || !got.write { - t.Fatalf("artifactPolicy(wire local without semantic support) = %+v, want write-only", got) - } - if got := loader.artifactPolicy(wireMeta, false, false); !got.read || !got.write { - t.Fatalf("artifactPolicy(wire external) = %+v, want read+write", got) - } -} - func TestRunGoListIncludesExportDataForReplacedModule(t *testing.T) { root := t.TempDir() depRoot := filepath.Join(root, "depmod") From df1f6f62d0a211079b98e357167dcea21ca22fd9 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 21:36:10 -0500 Subject: [PATCH 66/82] refactor: disable semantic reconstruction by default --- internal/wire/parse.go | 24 ++++++++++++-- internal/wire/parse_coverage_test.go | 2 ++ internal/wire/semantic_reconstruction_test.go | 33 +++++++++++++++++++ 3 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 internal/wire/semantic_reconstruction_test.go diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 2c33d99..747cf61 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -36,6 +36,8 @@ import ( 
"github.com/goforj/wire/internal/semanticcache" ) +const semanticReconstructionEnv = "WIRE_SEMANTIC_RECONSTRUCTION" + // A providerSetSrc captures the source for a type provided by a ProviderSet. // Exactly one of the fields will be set. type providerSetSrc struct { @@ -491,7 +493,9 @@ func newObjectCacheWithEnv(pkgs []*packages.Package, env []string) *objectCache // call to packages.Load and an import path X, there will exist only // one *packages.Package value with PkgPath X. oc.registerPackages(pkgs, false) - oc.recordSemanticArtifacts() + if semanticReconstructionEnabled(env) { + oc.recordSemanticArtifacts() + } return oc } @@ -588,6 +592,9 @@ func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, } func (oc *objectCache) semanticProviderSetArtifact(obj *types.Var) (semanticcache.ProviderSetArtifact, bool) { + if !semanticReconstructionEnabled(oc.env) { + return semanticcache.ProviderSetArtifact{}, false + } pkg := oc.packages[obj.Pkg().Path()] if pkg == nil { return semanticcache.ProviderSetArtifact{}, false @@ -926,6 +933,9 @@ func lookupQuotedStructField(st *types.Struct, quotedName string) (*types.Var, i } func (oc *objectCache) semanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { + if !semanticReconstructionEnabled(oc.env) { + return nil + } if pkg == nil { return nil } @@ -945,7 +955,7 @@ func (oc *objectCache) semanticArtifact(pkg *packages.Package) *semanticcache.Pa } func (oc *objectCache) recordSemanticArtifacts() { - if len(oc.env) == 0 { + if len(oc.env) == 0 || !semanticReconstructionEnabled(oc.env) { return } for _, pkg := range oc.packages { @@ -962,6 +972,16 @@ func (oc *objectCache) recordSemanticArtifacts() { } } +func semanticReconstructionEnabled(env []string) bool { + for i := len(env) - 1; i >= 0; i-- { + key, value, ok := strings.Cut(env[i], "=") + if ok && key == semanticReconstructionEnv { + return value == "1" + } + } + return false +} + func semanticArtifactInputs(env []string, pkg 
*packages.Package) (importPath, packageName string, files []string, ok bool) { if len(env) == 0 || pkg == nil || len(pkg.GoFiles) == 0 { return "", "", nil, false diff --git a/internal/wire/parse_coverage_test.go b/internal/wire/parse_coverage_test.go index 8b7e68a..4ba2a30 100644 --- a/internal/wire/parse_coverage_test.go +++ b/internal/wire/parse_coverage_test.go @@ -345,6 +345,7 @@ func TestObjectCacheSemanticProviderSetFallback(t *testing.T) { Imports: make(map[string]*packages.Package), } oc := &objectCache{ + env: []string{"WIRE_SEMANTIC_RECONSTRUCTION=1"}, fset: fset, packages: map[string]*packages.Package{depPkg.PkgPath: depPkg}, objects: make(map[objRef]objCacheEntry), @@ -406,6 +407,7 @@ func TestObjectCacheSemanticProviderSetSkipsBindArtifacts(t *testing.T) { Imports: make(map[string]*packages.Package), } oc := &objectCache{ + env: []string{"WIRE_SEMANTIC_RECONSTRUCTION=1"}, fset: fset, packages: map[string]*packages.Package{appPkg.PkgPath: appPkg}, objects: make(map[objRef]objCacheEntry), diff --git a/internal/wire/semantic_reconstruction_test.go b/internal/wire/semantic_reconstruction_test.go new file mode 100644 index 0000000..79cbe3c --- /dev/null +++ b/internal/wire/semantic_reconstruction_test.go @@ -0,0 +1,33 @@ +package wire + +import "testing" + +func TestSemanticReconstructionEnabled(t *testing.T) { + tests := []struct { + name string + env []string + want bool + }{ + { + name: "disabled by default", + want: false, + }, + { + name: "enabled by env", + env: []string{"WIRE_SEMANTIC_RECONSTRUCTION=1"}, + want: true, + }, + { + name: "disabled by env", + env: []string{"WIRE_SEMANTIC_RECONSTRUCTION=0"}, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := semanticReconstructionEnabled(tt.env); got != tt.want { + t.Fatalf("semanticReconstructionEnabled(%v) = %v, want %v", tt.env, got, tt.want) + } + }) + } +} From 3927014faae17b14931efa46ed23296f5631a8b4 Mon Sep 17 00:00:00 2001 From: Chris Miles 
Date: Mon, 16 Mar 2026 22:20:11 -0500 Subject: [PATCH 67/82] refactor: remove semantic reconstruction path --- internal/wire/parse.go | 570 +----------------- internal/wire/parse_coverage_test.go | 234 ------- internal/wire/semantic_reconstruction_test.go | 33 - internal/wire/wire.go | 2 +- 4 files changed, 2 insertions(+), 837 deletions(-) delete mode 100644 internal/wire/semantic_reconstruction_test.go diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 747cf61..4350baa 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -33,11 +33,8 @@ import ( "golang.org/x/tools/go/types/typeutil" "github.com/goforj/wire/internal/loader" - "github.com/goforj/wire/internal/semanticcache" ) -const semanticReconstructionEnv = "WIRE_SEMANTIC_RECONSTRUCTION" - // A providerSetSrc captures the source for a type provided by a ProviderSet. // Exactly one of the fields will be set. type providerSetSrc struct { @@ -270,7 +267,7 @@ func Load(ctx context.Context, wd string, env []string, tags string, patterns [] Fset: fset, Sets: make(map[ProviderSetID]*ProviderSet), } - oc := newObjectCacheWithEnv(pkgs, env) + oc := newObjectCache(pkgs) ec := new(errorCollector) for _, pkg := range pkgs { if isWireImport(pkg.PkgPath) { @@ -455,10 +452,8 @@ func (in *Injector) String() string { // objectCache is a lazily evaluated mapping of objects to Wire structures. 
type objectCache struct { fset *token.FileSet - env []string packages map[string]*packages.Package objects map[objRef]objCacheEntry - semantic map[string]*semanticcache.PackageArtifact hasher typeutil.Hasher } @@ -473,19 +468,13 @@ type objCacheEntry struct { } func newObjectCache(pkgs []*packages.Package) *objectCache { - return newObjectCacheWithEnv(pkgs, nil) -} - -func newObjectCacheWithEnv(pkgs []*packages.Package, env []string) *objectCache { if len(pkgs) == 0 { panic("object cache must have packages to draw from") } oc := &objectCache{ fset: pkgs[0].Fset, - env: append([]string(nil), env...), packages: make(map[string]*packages.Package), objects: make(map[objRef]objCacheEntry), - semantic: make(map[string]*semanticcache.PackageArtifact), hasher: typeutil.MakeHasher(), } // Depth-first search of all dependencies to gather import path to @@ -493,9 +482,6 @@ func newObjectCacheWithEnv(pkgs []*packages.Package, env []string) *objectCache // call to packages.Load and an import path X, there will exist only // one *packages.Package value with PkgPath X. 
oc.registerPackages(pkgs, false) - if semanticReconstructionEnabled(env) { - oc.recordSemanticArtifacts() - } return oc } @@ -542,12 +528,6 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { switch obj := obj.(type) { case *types.Var: spec := oc.varDecl(obj) - if isProviderSetType(obj.Type()) { - if spec == nil { - pset, _, errs := oc.semanticProviderSet(obj) - return pset, errs - } - } if spec == nil || len(spec.Values) == 0 { return nil, []error{fmt.Errorf("%v is not a provider or a provider set", obj)} } @@ -566,196 +546,6 @@ func (oc *objectCache) get(obj types.Object) (val interface{}, errs []error) { } } -func (oc *objectCache) semanticProviderSet(obj *types.Var) (*ProviderSet, bool, []error) { - setArt, ok := oc.semanticProviderSetArtifact(obj) - if !ok { - return nil, false, nil - } - pset := &ProviderSet{ - Pos: obj.Pos(), - PkgPath: obj.Pkg().Path(), - VarName: obj.Name(), - } - ec := new(errorCollector) - for _, item := range setArt.Items { - if errs := oc.applySemanticProviderSetItem(pset, item); len(errs) > 0 { - ec.add(errs...) 
- } - } - if len(ec.errors) > 0 { - return nil, true, ec.errors - } - if errs := oc.finalizeProviderSet(pset); len(errs) > 0 { - return nil, true, errs - } - return pset, true, nil -} - -func (oc *objectCache) semanticProviderSetArtifact(obj *types.Var) (semanticcache.ProviderSetArtifact, bool) { - if !semanticReconstructionEnabled(oc.env) { - return semanticcache.ProviderSetArtifact{}, false - } - pkg := oc.packages[obj.Pkg().Path()] - if pkg == nil { - return semanticcache.ProviderSetArtifact{}, false - } - art := oc.semanticArtifact(pkg) - if art == nil || !art.Supported { - return semanticcache.ProviderSetArtifact{}, false - } - setArt, ok := art.Vars[obj.Name()] - if !ok { - return semanticcache.ProviderSetArtifact{}, false - } - if !semanticProviderSetArtifactSupported(setArt) { - return semanticcache.ProviderSetArtifact{}, false - } - return setArt, true -} - -func semanticProviderSetArtifactSupported(setArt semanticcache.ProviderSetArtifact) bool { - for _, item := range setArt.Items { - if item.Kind == "bind" { - return false - } - } - return true -} - -func (oc *objectCache) applySemanticProviderSetItem(pset *ProviderSet, item semanticcache.ProviderSetItemArtifact) []error { - switch item.Kind { - case "func": - providerObj, errs := oc.semanticProvider(item.ImportPath, item.Name) - if len(errs) > 0 { - return errs - } - pset.Providers = append(pset.Providers, providerObj) - return nil - case "set": - setObj, errs := oc.semanticImportedSet(item.ImportPath, item.Name) - if len(errs) > 0 { - return errs - } - pset.Imports = append(pset.Imports, setObj) - return nil - case "bind": - binding, errs := oc.semanticBinding(item) - if len(errs) > 0 { - return errs - } - pset.Bindings = append(pset.Bindings, binding) - return nil - case "struct": - providerObj, errs := oc.semanticStructProvider(item) - if len(errs) > 0 { - return errs - } - pset.Providers = append(pset.Providers, providerObj) - return nil - case "fields": - fields, errs := oc.semanticFields(item) - 
if len(errs) > 0 { - return errs - } - pset.Fields = append(pset.Fields, fields...) - return nil - default: - return []error{fmt.Errorf("unsupported semantic cache item kind %q", item.Kind)} - } -} - -func (oc *objectCache) semanticProvider(importPath, name string) (*Provider, []error) { - fn, err := oc.lookupPackageFunc(importPath, name) - if err != nil { - return nil, []error{err} - } - return processFuncProvider(oc.fset, fn) -} - -func (oc *objectCache) semanticImportedSet(importPath, name string) (*ProviderSet, []error) { - v, err := oc.lookupProviderSetVar(importPath, name) - if err != nil { - return nil, []error{err} - } - item, errs := oc.get(v) - if len(errs) > 0 { - return nil, errs - } - pset, ok := item.(*ProviderSet) - if !ok || pset == nil { - return nil, []error{fmt.Errorf("%s.%s did not resolve to a provider set", importPath, name)} - } - return pset, nil -} - -func (oc *objectCache) semanticBinding(item semanticcache.ProviderSetItemArtifact) (*IfaceBinding, []error) { - iface, err := oc.semanticType(item.Type) - if err != nil { - return nil, semanticErrors(err) - } - provided, err := oc.semanticType(item.Type2) - if err != nil { - return nil, semanticErrors(err) - } - return &IfaceBinding{ - Iface: iface, - Provided: provided, - }, nil -} - -func (oc *objectCache) semanticStructProvider(item semanticcache.ProviderSetItemArtifact) (*Provider, []error) { - typeName, err := oc.semanticTypeName(item.Type) - if err != nil { - return nil, semanticErrors(err) - } - out, st, ok := namedStructType(typeName) - if !ok { - return nil, semanticErrors(fmt.Errorf("%s.%s does not name a struct", item.Type.ImportPath, item.Type.Name)) - } - provider := newStructProvider(typeName, typeAndPointer(out)) - args, errs := semanticStructProviderInputs(st, item) - if len(errs) > 0 { - return nil, errs - } - provider.Args = args - return provider, nil -} - -func (oc *objectCache) semanticFields(item semanticcache.ProviderSetItemArtifact) ([]*Field, []error) { - parent, err 
:= oc.semanticType(item.Type) - if err != nil { - return nil, semanticErrors(err) - } - structType, ptrToField, err := structFromFieldsParent(parent) - if err != nil { - return nil, semanticErrors(err) - } - fields := make([]*Field, 0, len(item.FieldNames)) - for _, fieldName := range item.FieldNames { - v, err := requiredStructField(structType, fieldName) - if err != nil { - return nil, semanticErrors(err) - } - fields = append(fields, newField(parent, v, ptrToField)) - } - return fields, nil -} - -func semanticStructProviderInputs(st *types.Struct, item semanticcache.ProviderSetItemArtifact) ([]ProviderInput, []error) { - if item.AllFields { - return providerInputsForAllowedStructFields(st), nil - } - fields := make([]*types.Var, 0, len(item.FieldNames)) - for _, fieldName := range item.FieldNames { - f, err := requiredStructField(st, fieldName) - if err != nil { - return nil, []error{fmt.Errorf("field %q not found in %s.%s", fieldName, item.Type.ImportPath, item.Type.Name)} - } - fields = append(fields, f) - } - return providerInputsForVars(fields), nil -} - func providerInputsForAllowedStructFields(st *types.Struct) []ProviderInput { fields := make([]*types.Var, 0, st.NumFields()) for i := 0; i < st.NumFields(); i++ { @@ -775,10 +565,6 @@ func providerInputsForVars(vars []*types.Var) []ProviderInput { return args } -func semanticErrors(err error) []error { - return []error{err} -} - func providerInputForVar(v *types.Var) ProviderInput { return ProviderInput{ Type: v.Type(), @@ -818,18 +604,6 @@ func newStructProvider(typeName types.Object, out []types.Type) *Provider { } } -func (oc *objectCache) semanticType(ref semanticcache.TypeRef) (types.Type, error) { - typeName, err := oc.semanticTypeName(ref) - if err != nil { - return nil, err - } - return applyTypePointers(typeName.Type(), ref.Pointer), nil -} - -func (oc *objectCache) semanticTypeName(ref semanticcache.TypeRef) (*types.TypeName, error) { - return oc.lookupPackageTypeName(ref.ImportPath, ref.Name) -} 
- func (oc *objectCache) lookupPackageObject(importPath, name string) (types.Object, error) { pkg := oc.packages[importPath] if pkg == nil || pkg.Types == nil { @@ -838,18 +612,6 @@ func (oc *objectCache) lookupPackageObject(importPath, name string) (types.Objec return pkg.Types.Scope().Lookup(name), nil } -func (oc *objectCache) lookupPackageTypeName(importPath, name string) (*types.TypeName, error) { - obj, err := oc.lookupPackageObject(importPath, name) - if err != nil { - return nil, err - } - typeName, ok := obj.(*types.TypeName) - if !ok || typeName == nil { - return nil, fmt.Errorf("%s.%s is not a named type", importPath, name) - } - return typeName, nil -} - func (oc *objectCache) lookupPackageFunc(importPath, name string) (*types.Func, error) { obj, err := oc.lookupPackageObject(importPath, name) if err != nil { @@ -862,18 +624,6 @@ func (oc *objectCache) lookupPackageFunc(importPath, name string) (*types.Func, return fn, nil } -func (oc *objectCache) lookupProviderSetVar(importPath, name string) (*types.Var, error) { - obj, err := oc.lookupPackageObject(importPath, name) - if err != nil { - return nil, err - } - v, ok := obj.(*types.Var) - if !ok || v == nil || !isProviderSetType(v.Type()) { - return nil, fmt.Errorf("%s.%s is not a provider set", importPath, name) - } - return v, nil -} - func applyTypePointers(typ types.Type, count int) types.Type { for i := 0; i < count; i++ { typ = types.NewPointer(typ) @@ -932,324 +682,6 @@ func lookupQuotedStructField(st *types.Struct, quotedName string) (*types.Var, i return nil, -1 } -func (oc *objectCache) semanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { - if !semanticReconstructionEnabled(oc.env) { - return nil - } - if pkg == nil { - return nil - } - if art, ok := oc.semantic[pkg.PkgPath]; ok { - return art - } - importPath, packageName, files, ok := semanticArtifactInputs(oc.env, pkg) - if !ok { - return nil - } - art, err := readSemanticArtifact(oc.env, importPath, packageName, files) - 
if err != nil { - return nil - } - oc.semantic[pkg.PkgPath] = art - return art -} - -func (oc *objectCache) recordSemanticArtifacts() { - if len(oc.env) == 0 || !semanticReconstructionEnabled(oc.env) { - return - } - for _, pkg := range oc.packages { - importPath, packageName, files, ok := semanticArtifactInputs(oc.env, pkg) - if !ok || len(pkg.Syntax) == 0 || pkg.Types == nil || pkg.TypesInfo == nil { - continue - } - art := buildSemanticArtifact(pkg) - if art == nil { - continue - } - oc.semantic[pkg.PkgPath] = art - _ = writeSemanticArtifact(oc.env, importPath, packageName, files, art) - } -} - -func semanticReconstructionEnabled(env []string) bool { - for i := len(env) - 1; i >= 0; i-- { - key, value, ok := strings.Cut(env[i], "=") - if ok && key == semanticReconstructionEnv { - return value == "1" - } - } - return false -} - -func semanticArtifactInputs(env []string, pkg *packages.Package) (importPath, packageName string, files []string, ok bool) { - if len(env) == 0 || pkg == nil || len(pkg.GoFiles) == 0 { - return "", "", nil, false - } - return pkg.PkgPath, pkg.Name, pkg.GoFiles, true -} - -func readSemanticArtifact(env []string, importPath, packageName string, files []string) (*semanticcache.PackageArtifact, error) { - return semanticcache.Read(env, importPath, packageName, files) -} - -func writeSemanticArtifact(env []string, importPath, packageName string, files []string, art *semanticcache.PackageArtifact) error { - return semanticcache.Write(env, importPath, packageName, files, art) -} - -func buildSemanticArtifact(pkg *packages.Package) *semanticcache.PackageArtifact { - if pkg == nil || pkg.Types == nil || pkg.TypesInfo == nil { - return nil - } - art := &semanticcache.PackageArtifact{ - Version: 1, - PackagePath: pkg.PkgPath, - PackageName: pkg.Name, - Supported: true, - Vars: make(map[string]semanticcache.ProviderSetArtifact), - } - scope := pkg.Types.Scope() - for _, name := range scope.Names() { - obj := scope.Lookup(name) - v, ok := 
obj.(*types.Var) - if !ok || !isProviderSetType(v.Type()) { - continue - } - art.HasProviderSetVars = true - spec := semanticVarDecl(pkg, v) - if spec == nil || len(spec.Values) == 0 { - art.Supported = false - continue - } - var idx int - found := false - for i := range spec.Names { - if spec.Names[i].Name == v.Name() { - idx = i - found = true - break - } - } - if !found || idx >= len(spec.Values) { - art.Supported = false - continue - } - setArt, ok := summarizeSemanticProviderSet(pkg.TypesInfo, spec.Values[idx], pkg.PkgPath) - if !ok { - art.Supported = false - continue - } - art.Vars[v.Name()] = setArt - } - return art -} - -func summarizeSemanticProviderSet(info *types.Info, expr ast.Expr, pkgPath string) (semanticcache.ProviderSetArtifact, bool) { - call, ok := astutil.Unparen(expr).(*ast.CallExpr) - if !ok { - return semanticcache.ProviderSetArtifact{}, false - } - fnObj := qualifiedIdentObject(info, call.Fun) - if fnObj == nil || fnObj.Pkg() == nil || !isWireImport(fnObj.Pkg().Path()) || fnObj.Name() != "NewSet" { - return semanticcache.ProviderSetArtifact{}, false - } - setArt := semanticcache.ProviderSetArtifact{ - Items: make([]semanticcache.ProviderSetItemArtifact, 0, len(call.Args)), - } - for _, arg := range call.Args { - items, ok := summarizeSemanticProviderSetArg(info, astutil.Unparen(arg), pkgPath) - if !ok { - return semanticcache.ProviderSetArtifact{}, false - } - setArt.Items = append(setArt.Items, items...) 
- } - if !semanticProviderSetArtifactSupported(setArt) { - return semanticcache.ProviderSetArtifact{}, false - } - return setArt, true -} - -func summarizeSemanticProviderSetArg(info *types.Info, expr ast.Expr, pkgPath string) ([]semanticcache.ProviderSetItemArtifact, bool) { - if obj := qualifiedIdentObject(info, expr); obj != nil && obj.Pkg() != nil && obj.Exported() { - item := semanticcache.ProviderSetItemArtifact{ - ImportPath: obj.Pkg().Path(), - Name: obj.Name(), - } - switch typed := obj.(type) { - case *types.Func: - item.Kind = "func" - case *types.Var: - if !isProviderSetType(typed.Type()) { - return nil, false - } - item.Kind = "set" - default: - return nil, false - } - if item.ImportPath == "" { - item.ImportPath = pkgPath - } - return []semanticcache.ProviderSetItemArtifact{item}, true - } - call, ok := expr.(*ast.CallExpr) - if !ok { - return nil, false - } - fnObj := qualifiedIdentObject(info, call.Fun) - if fnObj == nil || fnObj.Pkg() == nil || !isWireImport(fnObj.Pkg().Path()) { - return nil, false - } - switch fnObj.Name() { - case "NewSet": - nested, ok := summarizeSemanticProviderSet(info, call, pkgPath) - if !ok { - return nil, false - } - return nested.Items, true - case "Bind": - item, ok := summarizeSemanticBind(info, call) - if !ok { - return nil, false - } - return []semanticcache.ProviderSetItemArtifact{item}, true - case "Struct": - item, ok := summarizeSemanticStruct(info, call) - if !ok { - return nil, false - } - return []semanticcache.ProviderSetItemArtifact{item}, true - case "FieldsOf": - item, ok := summarizeSemanticFields(info, call) - if !ok { - return nil, false - } - return []semanticcache.ProviderSetItemArtifact{item}, true - default: - return nil, false - } -} - -func summarizeSemanticBind(info *types.Info, call *ast.CallExpr) (semanticcache.ProviderSetItemArtifact, bool) { - if len(call.Args) != 2 { - return semanticcache.ProviderSetItemArtifact{}, false - } - iface, ok := summarizeTypeRef(info.TypeOf(call.Args[0])) - if 
!ok || iface.Pointer == 0 { - return semanticcache.ProviderSetItemArtifact{}, false - } - iface.Pointer-- - providedType := info.TypeOf(call.Args[1]) - if bindShouldUsePointer(info, call) { - ptr, ok := providedType.(*types.Pointer) - if !ok { - return semanticcache.ProviderSetItemArtifact{}, false - } - providedType = ptr.Elem() - } - provided, ok := summarizeTypeRef(providedType) - if !ok { - return semanticcache.ProviderSetItemArtifact{}, false - } - return semanticcache.ProviderSetItemArtifact{ - Kind: "bind", - Type: iface, - Type2: provided, - }, true -} - -func summarizeSemanticStruct(info *types.Info, call *ast.CallExpr) (semanticcache.ProviderSetItemArtifact, bool) { - if len(call.Args) < 1 { - return semanticcache.ProviderSetItemArtifact{}, false - } - structType := info.TypeOf(call.Args[0]) - ptr, ok := structType.(*types.Pointer) - if !ok { - return semanticcache.ProviderSetItemArtifact{}, false - } - ref, ok := summarizeTypeRef(ptr.Elem()) - if !ok { - return semanticcache.ProviderSetItemArtifact{}, false - } - item := semanticcache.ProviderSetItemArtifact{ - Kind: "struct", - Type: ref, - } - if allFields(call) { - item.AllFields = true - return item, true - } - item.FieldNames = make([]string, 0, len(call.Args)-1) - for i := 1; i < len(call.Args); i++ { - lit, ok := call.Args[i].(*ast.BasicLit) - if !ok { - return semanticcache.ProviderSetItemArtifact{}, false - } - fieldName, err := strconv.Unquote(lit.Value) - if err != nil { - return semanticcache.ProviderSetItemArtifact{}, false - } - item.FieldNames = append(item.FieldNames, fieldName) - } - return item, true -} - -func summarizeSemanticFields(info *types.Info, call *ast.CallExpr) (semanticcache.ProviderSetItemArtifact, bool) { - if len(call.Args) < 2 { - return semanticcache.ProviderSetItemArtifact{}, false - } - parent, ok := summarizeTypeRef(info.TypeOf(call.Args[0])) - if !ok { - return semanticcache.ProviderSetItemArtifact{}, false - } - item := semanticcache.ProviderSetItemArtifact{ - 
Kind: "fields", - Type: parent, - FieldNames: make([]string, 0, len(call.Args)-1), - } - for i := 1; i < len(call.Args); i++ { - lit, ok := call.Args[i].(*ast.BasicLit) - if !ok { - return semanticcache.ProviderSetItemArtifact{}, false - } - fieldName, err := strconv.Unquote(lit.Value) - if err != nil { - return semanticcache.ProviderSetItemArtifact{}, false - } - item.FieldNames = append(item.FieldNames, fieldName) - } - return item, true -} - -func summarizeTypeRef(typ types.Type) (semanticcache.TypeRef, bool) { - ref := semanticcache.TypeRef{} - for { - ptr, ok := typ.(*types.Pointer) - if !ok { - break - } - ref.Pointer++ - typ = ptr.Elem() - } - named, ok := typ.(*types.Named) - if !ok { - return semanticcache.TypeRef{}, false - } - obj := named.Obj() - if obj == nil || obj.Pkg() == nil { - return semanticcache.TypeRef{}, false - } - ref.ImportPath = obj.Pkg().Path() - ref.Name = obj.Name() - return ref, true -} - -func semanticVarDecl(pkg *packages.Package, obj *types.Var) *ast.ValueSpec { - if pkg == nil { - return nil - } - return valueSpecForVar(pkg.Fset, pkg.Syntax, obj) -} - // varDecl finds the declaration that defines the given variable. func (oc *objectCache) varDecl(obj *types.Var) *ast.ValueSpec { // TODO(light): Walk files to build object -> declaration mapping, if more performant. 
diff --git a/internal/wire/parse_coverage_test.go b/internal/wire/parse_coverage_test.go index 4ba2a30..7c7a3b7 100644 --- a/internal/wire/parse_coverage_test.go +++ b/internal/wire/parse_coverage_test.go @@ -22,8 +22,6 @@ import ( "golang.org/x/tools/go/packages" "golang.org/x/tools/go/types/typeutil" - - "github.com/goforj/wire/internal/semanticcache" ) func TestFindInjectorBuildVariants(t *testing.T) { @@ -222,238 +220,6 @@ func TestProcessStructProviderDuplicateFields(t *testing.T) { } } -func TestSummarizeSemanticProviderSet(t *testing.T) { - t.Parallel() - - info := &types.Info{ - Uses: make(map[*ast.Ident]types.Object), - } - wirePkg := types.NewPackage("github.com/goforj/wire", "wire") - wireIdent := ast.NewIdent("wire") - newSetIdent := ast.NewIdent("NewSet") - info.Uses[wireIdent] = types.NewPkgName(token.NoPos, nil, "wire", wirePkg) - info.Uses[newSetIdent] = types.NewFunc(token.NoPos, wirePkg, "NewSet", nil) - - depPkg := types.NewPackage("example.com/dep", "dep") - fnIdent := ast.NewIdent("NewMessage") - info.Uses[fnIdent] = types.NewFunc(token.NoPos, depPkg, "NewMessage", types.NewSignatureType(nil, nil, nil, nil, types.NewTuple(types.NewVar(token.NoPos, depPkg, "", types.Typ[types.String])), false)) - - call := &ast.CallExpr{ - Fun: &ast.SelectorExpr{X: wireIdent, Sel: newSetIdent}, - Args: []ast.Expr{ - fnIdent, - }, - } - got, ok := summarizeSemanticProviderSet(info, call, "example.com/app") - if !ok { - t.Fatal("summarizeSemanticProviderSet() = unsupported, want supported") - } - if len(got.Items) != 1 { - t.Fatalf("items len = %d, want 1", len(got.Items)) - } - if got.Items[0].Kind != "func" || got.Items[0].ImportPath != "example.com/dep" || got.Items[0].Name != "NewMessage" { - t.Fatalf("unexpected item: %+v", got.Items[0]) - } -} - -func TestSummarizeSemanticProviderSetTypeOnlyForms(t *testing.T) { - t.Parallel() - - info := &types.Info{ - Types: make(map[ast.Expr]types.TypeAndValue), - Uses: make(map[*ast.Ident]types.Object), - } - wirePkg := 
types.NewPackage("github.com/goforj/wire", "wire") - wireIdent := ast.NewIdent("wire") - info.Uses[wireIdent] = types.NewPkgName(token.NoPos, nil, "wire", wirePkg) - - appPkg := types.NewPackage("example.com/app", "app") - fooObj := types.NewTypeName(token.NoPos, appPkg, "Foo", nil) - fooNamed := types.NewNamed(fooObj, types.NewStruct([]*types.Var{ - types.NewVar(token.NoPos, appPkg, "Message", types.Typ[types.String]), - }, []string{""}), nil) - fooIfaceObj := types.NewTypeName(token.NoPos, appPkg, "Fooer", nil) - fooIface := types.NewNamed(fooIfaceObj, types.NewInterfaceType(nil, nil).Complete(), nil) - - newSetIdent := ast.NewIdent("NewSet") - bindIdent := ast.NewIdent("Bind") - structIdent := ast.NewIdent("Struct") - fieldsIdent := ast.NewIdent("FieldsOf") - info.Uses[newSetIdent] = types.NewFunc(token.NoPos, wirePkg, "NewSet", nil) - info.Uses[bindIdent] = types.NewFunc(token.NoPos, wirePkg, "Bind", nil) - info.Uses[structIdent] = types.NewFunc(token.NoPos, wirePkg, "Struct", nil) - info.Uses[fieldsIdent] = types.NewFunc(token.NoPos, wirePkg, "FieldsOf", nil) - - newFooCall := &ast.CallExpr{Fun: ast.NewIdent("new"), Args: []ast.Expr{ast.NewIdent("Foo")}} - info.Types[newFooCall] = types.TypeAndValue{Type: types.NewPointer(fooNamed)} - newFooIfaceCall := &ast.CallExpr{Fun: ast.NewIdent("new"), Args: []ast.Expr{ast.NewIdent("Fooer")}} - info.Types[newFooIfaceCall] = types.TypeAndValue{Type: types.NewPointer(fooIface)} - ptrToPtrFooCall := &ast.CallExpr{Fun: ast.NewIdent("new"), Args: []ast.Expr{ast.NewIdent("FooPtr")}} - info.Types[ptrToPtrFooCall] = types.TypeAndValue{Type: types.NewPointer(types.NewPointer(fooNamed))} - - call := &ast.CallExpr{ - Fun: &ast.SelectorExpr{X: wireIdent, Sel: newSetIdent}, - Args: []ast.Expr{ - &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: bindIdent}, Args: []ast.Expr{newFooIfaceCall, newFooCall}}, - &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: structIdent}, Args: []ast.Expr{newFooCall, &ast.BasicLit{Kind: 
token.STRING, Value: "\"*\""}}}, - &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: fieldsIdent}, Args: []ast.Expr{ptrToPtrFooCall, &ast.BasicLit{Kind: token.STRING, Value: "\"Message\""}}}, - }, - } - if got, ok := summarizeSemanticProviderSet(info, call, "example.com/app"); ok || len(got.Items) != 0 { - t.Fatalf("summarizeSemanticProviderSet(bind case) = (%+v, %v), want unsupported", got, ok) - } - - call = &ast.CallExpr{ - Fun: &ast.SelectorExpr{X: wireIdent, Sel: newSetIdent}, - Args: []ast.Expr{ - &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: structIdent}, Args: []ast.Expr{newFooCall, &ast.BasicLit{Kind: token.STRING, Value: "\"*\""}}}, - &ast.CallExpr{Fun: &ast.SelectorExpr{X: wireIdent, Sel: fieldsIdent}, Args: []ast.Expr{ptrToPtrFooCall, &ast.BasicLit{Kind: token.STRING, Value: "\"Message\""}}}, - }, - } - got, ok := summarizeSemanticProviderSet(info, call, "example.com/app") - if !ok { - t.Fatal("summarizeSemanticProviderSet(non-bind type-only forms) = unsupported, want supported") - } - if len(got.Items) != 2 { - t.Fatalf("items len = %d, want 2", len(got.Items)) - } - if got.Items[0].Kind != "struct" || got.Items[1].Kind != "fields" { - t.Fatalf("unexpected kinds: %+v", got.Items) - } -} - -func TestObjectCacheSemanticProviderSetFallback(t *testing.T) { - t.Parallel() - - fset := token.NewFileSet() - wirePkg := types.NewPackage("github.com/goforj/wire", "wire") - wireObj := types.NewTypeName(token.NoPos, wirePkg, "ProviderSet", nil) - wireNamed := types.NewNamed(wireObj, types.NewStruct(nil, nil), nil) - - depTypes := types.NewPackage("example.com/dep", "dep") - msgFnSig := types.NewSignatureType(nil, nil, nil, nil, types.NewTuple(types.NewVar(token.NoPos, depTypes, "", types.Typ[types.String])), false) - msgFn := types.NewFunc(token.NoPos, depTypes, "NewMessage", msgFnSig) - setVar := types.NewVar(token.NoPos, depTypes, "Set", wireNamed) - depTypes.Scope().Insert(msgFn) - depTypes.Scope().Insert(setVar) - - depPkg := &packages.Package{ 
- Name: "dep", - PkgPath: depTypes.Path(), - Types: depTypes, - Fset: fset, - Imports: make(map[string]*packages.Package), - } - oc := &objectCache{ - env: []string{"WIRE_SEMANTIC_RECONSTRUCTION=1"}, - fset: fset, - packages: map[string]*packages.Package{depPkg.PkgPath: depPkg}, - objects: make(map[objRef]objCacheEntry), - semantic: map[string]*semanticcache.PackageArtifact{ - depPkg.PkgPath: { - Version: 1, - PackagePath: depPkg.PkgPath, - PackageName: depPkg.Name, - Supported: true, - Vars: map[string]semanticcache.ProviderSetArtifact{ - "Set": { - Items: []semanticcache.ProviderSetItemArtifact{ - {Kind: "func", ImportPath: depPkg.PkgPath, Name: "NewMessage"}, - }, - }, - }, - }, - }, - hasher: typeutil.MakeHasher(), - } - item, errs := oc.get(setVar) - if len(errs) > 0 { - t.Fatalf("oc.get(Set) errs = %v", errs) - } - pset, ok := item.(*ProviderSet) - if !ok || pset == nil { - t.Fatalf("oc.get(Set) type = %T, want *ProviderSet", item) - } - if len(pset.Providers) != 1 || pset.Providers[0].Name != "NewMessage" { - t.Fatalf("unexpected providers: %+v", pset.Providers) - } -} - -func TestObjectCacheSemanticProviderSetSkipsBindArtifacts(t *testing.T) { - t.Parallel() - - fset := token.NewFileSet() - wirePkg := types.NewPackage("github.com/goforj/wire", "wire") - wireObj := types.NewTypeName(token.NoPos, wirePkg, "ProviderSet", nil) - wireNamed := types.NewNamed(wireObj, types.NewStruct(nil, nil), nil) - - appTypes := types.NewPackage("example.com/app", "app") - fooIfaceObj := types.NewTypeName(token.NoPos, appTypes, "Fooer", nil) - _ = types.NewNamed(fooIfaceObj, types.NewInterfaceType(nil, nil).Complete(), nil) - fooObj := types.NewTypeName(token.NoPos, appTypes, "Foo", nil) - _ = types.NewNamed(fooObj, types.NewStruct([]*types.Var{ - types.NewVar(token.NoPos, appTypes, "Message", types.Typ[types.String]), - }, []string{""}), nil) - setVar := types.NewVar(token.NoPos, appTypes, "Set", wireNamed) - appTypes.Scope().Insert(fooIfaceObj) - 
appTypes.Scope().Insert(fooObj) - appTypes.Scope().Insert(setVar) - - appPkg := &packages.Package{ - Name: "app", - PkgPath: appTypes.Path(), - Types: appTypes, - Fset: fset, - Imports: make(map[string]*packages.Package), - } - oc := &objectCache{ - env: []string{"WIRE_SEMANTIC_RECONSTRUCTION=1"}, - fset: fset, - packages: map[string]*packages.Package{appPkg.PkgPath: appPkg}, - objects: make(map[objRef]objCacheEntry), - semantic: map[string]*semanticcache.PackageArtifact{ - appPkg.PkgPath: { - Version: 1, - PackagePath: appPkg.PkgPath, - PackageName: appPkg.Name, - Supported: true, - Vars: map[string]semanticcache.ProviderSetArtifact{ - "Set": { - Items: []semanticcache.ProviderSetItemArtifact{ - { - Kind: "bind", - Type: semanticcache.TypeRef{ImportPath: appPkg.PkgPath, Name: "Fooer"}, - Type2: semanticcache.TypeRef{ImportPath: appPkg.PkgPath, Name: "Foo"}, - }, - { - Kind: "struct", - Type: semanticcache.TypeRef{ImportPath: appPkg.PkgPath, Name: "Foo"}, - AllFields: true, - }, - { - Kind: "fields", - Type: semanticcache.TypeRef{ImportPath: appPkg.PkgPath, Name: "Foo", Pointer: 2}, - FieldNames: []string{"Message"}, - }, - }, - }, - }, - }, - }, - hasher: typeutil.MakeHasher(), - } - pset, ok, errs := oc.semanticProviderSet(setVar) - if len(errs) > 0 { - t.Fatalf("semanticProviderSet(Set) errs = %v", errs) - } - if ok { - t.Fatalf("semanticProviderSet(Set) ok = true, want false") - } - if pset != nil { - t.Fatalf("semanticProviderSet(Set) = %#v, want nil", pset) - } -} - func TestProcessFuncProviderErrors(t *testing.T) { t.Parallel() diff --git a/internal/wire/semantic_reconstruction_test.go b/internal/wire/semantic_reconstruction_test.go deleted file mode 100644 index 79cbe3c..0000000 --- a/internal/wire/semantic_reconstruction_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package wire - -import "testing" - -func TestSemanticReconstructionEnabled(t *testing.T) { - tests := []struct { - name string - env []string - want bool - }{ - { - name: "disabled by default", - 
want: false, - }, - { - name: "enabled by env", - env: []string{"WIRE_SEMANTIC_RECONSTRUCTION=1"}, - want: true, - }, - { - name: "disabled by env", - env: []string{"WIRE_SEMANTIC_RECONSTRUCTION=0"}, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := semanticReconstructionEnabled(tt.env); got != tt.want { - t.Fatalf("semanticReconstructionEnabled(%v) = %v, want %v", tt.env, got, tt.want) - } - }) - } -} diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 2459723..1c44eba 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -125,7 +125,7 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o } generated[i].OutputPath = filepath.Join(outDir, opts.PrefixOutputFile+"wire_gen.go") g := newGen(pkg) - oc := newObjectCacheWithEnv([]*packages.Package{pkg}, env) + oc := newObjectCache([]*packages.Package{pkg}) injectorStart := time.Now() injectorFiles, genErrs := generateInjectors(oc, g, pkg) logTiming(ctx, "generate.package."+pkg.PkgPath+".injectors", injectorStart) From 4af9edcf37168c3a130fc29e09e06cc62d48a70b Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 23:44:22 -0500 Subject: [PATCH 68/82] refactor: remove semantic cache layer --- cmd/wire/cache_cmd.go | 1 - cmd/wire/cache_cmd_test.go | 20 ++--- internal/loader/custom.go | 22 ----- internal/semanticcache/cache.go | 143 -------------------------------- 4 files changed, 8 insertions(+), 178 deletions(-) delete mode 100644 internal/semanticcache/cache.go diff --git a/cmd/wire/cache_cmd.go b/cmd/wire/cache_cmd.go index cdbbd40..1bc4560 100644 --- a/cmd/wire/cache_cmd.go +++ b/cmd/wire/cache_cmd.go @@ -152,7 +152,6 @@ func wireCacheTargets(env []string, userCacheDir string) []cacheTarget { targets := []cacheTarget{ {name: "loader-artifacts", path: envValueDefault(env, loaderArtifactDirEnv, filepath.Join(baseWire, "loader-artifacts"))}, {name: "discovery-cache", path: filepath.Join(baseWire, 
"discovery-cache")}, - {name: "semantic-artifacts", path: envValueDefault(env, semanticCacheDirEnv, filepath.Join(baseWire, "semantic-artifacts"))}, {name: "output-cache", path: envValueDefault(env, outputCacheDirEnv, filepath.Join(baseWire, "output-cache"))}, } seen := make(map[string]bool, len(targets)) diff --git a/cmd/wire/cache_cmd_test.go b/cmd/wire/cache_cmd_test.go index 83924e2..578c2aa 100644 --- a/cmd/wire/cache_cmd_test.go +++ b/cmd/wire/cache_cmd_test.go @@ -10,10 +10,9 @@ func TestWireCacheTargetsDefault(t *testing.T) { base := filepath.Join(t.TempDir(), "cache") got := wireCacheTargets(nil, base) want := map[string]string{ - "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), - "loader-artifacts": filepath.Join(base, "wire", "loader-artifacts"), - "output-cache": filepath.Join(base, "wire", "output-cache"), - "semantic-artifacts": filepath.Join(base, "wire", "semantic-artifacts"), + "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), + "loader-artifacts": filepath.Join(base, "wire", "loader-artifacts"), + "output-cache": filepath.Join(base, "wire", "output-cache"), } if len(got) != len(want) { t.Fatalf("targets len = %d, want %d", len(got), len(want)) @@ -46,14 +45,12 @@ func TestWireCacheTargetsRespectOverrides(t *testing.T) { env := []string{ loaderArtifactDirEnv + "=" + filepath.Join(base, "loader"), outputCacheDirEnv + "=" + filepath.Join(base, "output"), - semanticCacheDirEnv + "=" + filepath.Join(base, "semantic"), } got := wireCacheTargets(env, base) want := map[string]string{ - "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), - "loader-artifacts": filepath.Join(base, "loader"), - "output-cache": filepath.Join(base, "output"), - "semantic-artifacts": filepath.Join(base, "semantic"), + "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), + "loader-artifacts": filepath.Join(base, "loader"), + "output-cache": filepath.Join(base, "output"), } for _, target := range got { if target.path 
!= want[target.name] { @@ -67,7 +64,6 @@ func TestClearWireCachesRemovesTargets(t *testing.T) { env := []string{ loaderArtifactDirEnv + "=" + filepath.Join(base, "loader"), outputCacheDirEnv + "=" + filepath.Join(base, "output"), - semanticCacheDirEnv + "=" + filepath.Join(base, "semantic"), } for _, target := range wireCacheTargets(env, base) { if err := os.MkdirAll(target.path, 0o755); err != nil { @@ -85,8 +81,8 @@ func TestClearWireCachesRemovesTargets(t *testing.T) { if err != nil { t.Fatalf("clearWireCaches() error = %v", err) } - if len(cleared) != 4 { - t.Fatalf("cleared len = %d, want 4", len(cleared)) + if len(cleared) != 3 { + t.Fatalf("cleared len = %d, want 3", len(cleared)) } for _, target := range wireCacheTargets(env, base) { if _, err := os.Stat(target.path); !os.IsNotExist(err) { diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 007eebd..6fa586b 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -34,8 +34,6 @@ import ( "golang.org/x/tools/go/gcexportdata" "golang.org/x/tools/go/packages" - - "github.com/goforj/wire/internal/semanticcache" ) type unsupportedError struct { @@ -93,7 +91,6 @@ type customTypedGraphLoader struct { importer types.Importer loading map[string]bool isLocalCache map[string]bool - localSemanticOK map[string]bool artifactPrefetch map[string]artifactPrefetchEntry stats typedLoadStats } @@ -530,22 +527,6 @@ func (l *customTypedGraphLoader) loadPackage(path string) (*packages.Package, er return pkg, nil } -func (l *customTypedGraphLoader) localSemanticArtifactSupported(meta *packageMeta) bool { - if meta == nil { - return false - } - if ok, exists := l.localSemanticOK[meta.ImportPath]; exists { - return ok - } - art, err := semanticcache.Read(l.env, meta.ImportPath, meta.Name, metaFiles(meta)) - if err != nil || art == nil { - l.localSemanticOK[meta.ImportPath] = false - return false - } - l.localSemanticOK[meta.ImportPath] = art.Supported - return art.Supported -} - func (l 
*customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isLocal bool) artifactPolicy { if !loaderArtifactEnabled(l.env) || isTarget { return artifactPolicy{} @@ -553,9 +534,7 @@ func (l *customTypedGraphLoader) artifactPolicy(meta *packageMeta, isTarget, isL policy := artifactPolicy{write: true} if !isLocal { policy.read = true - return policy } - policy.read = l.localSemanticArtifactSupported(meta) return policy } @@ -1185,7 +1164,6 @@ func newCustomTypedGraphLoader(ctx context.Context, wd string, env []string, fse importer: importerpkg.ForCompiler(token.NewFileSet(), "gc", nil), loading: make(map[string]bool, len(meta)), isLocalCache: make(map[string]bool, len(meta)), - localSemanticOK: make(map[string]bool, len(meta)), artifactPrefetch: make(map[string]artifactPrefetchEntry, len(meta)), stats: typedLoadStats{discovery: discoveryDuration}, } diff --git a/internal/semanticcache/cache.go b/internal/semanticcache/cache.go deleted file mode 100644 index 4442415..0000000 --- a/internal/semanticcache/cache.go +++ /dev/null @@ -1,143 +0,0 @@ -package semanticcache - -import ( - "crypto/sha256" - "encoding/gob" - "encoding/hex" - "os" - "path/filepath" - "runtime" - "strconv" -) - -const dirEnv = "WIRE_SEMANTIC_CACHE_DIR" - -type PackageArtifact struct { - Version int - PackagePath string - PackageName string - HasProviderSetVars bool - Supported bool - Vars map[string]ProviderSetArtifact -} - -type ProviderSetArtifact struct { - Items []ProviderSetItemArtifact -} - -type ProviderSetItemArtifact struct { - Kind string - ImportPath string - Name string - Type TypeRef - Type2 TypeRef - FieldNames []string - AllFields bool -} - -type TypeRef struct { - ImportPath string - Name string - Pointer int -} - -func ArtifactPath(env []string, importPath, packageName string, files []string) (string, error) { - dir, err := artifactDir(env) - if err != nil { - return "", err - } - key, err := artifactKey(importPath, packageName, files) - if err != nil { - return "", err - 
} - return filepath.Join(dir, key+".gob"), nil -} - -func Read(env []string, importPath, packageName string, files []string) (*PackageArtifact, error) { - path, err := ArtifactPath(env, importPath, packageName, files) - if err != nil { - return nil, err - } - f, err := os.Open(path) - if err != nil { - return nil, err - } - defer f.Close() - var art PackageArtifact - if err := gob.NewDecoder(f).Decode(&art); err != nil { - return nil, err - } - return &art, nil -} - -func Write(env []string, importPath, packageName string, files []string, art *PackageArtifact) error { - path, err := ArtifactPath(env, importPath, packageName, files) - if err != nil { - return err - } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return err - } - f, err := os.Create(path) - if err != nil { - return err - } - defer f.Close() - return gob.NewEncoder(f).Encode(art) -} - -func Exists(env []string, importPath, packageName string, files []string) bool { - path, err := ArtifactPath(env, importPath, packageName, files) - if err != nil { - return false - } - _, err = os.Stat(path) - return err == nil -} - -func artifactDir(env []string) (string, error) { - for i := len(env) - 1; i >= 0; i-- { - key, val, ok := splitEnv(env[i]) - if ok && key == dirEnv && val != "" { - return val, nil - } - } - base, err := os.UserCacheDir() - if err != nil { - return "", err - } - return filepath.Join(base, "wire", "semantic-artifacts"), nil -} - -func artifactKey(importPath, packageName string, files []string) (string, error) { - sum := sha256.New() - sum.Write([]byte("wire-semantic-artifact-v1\n")) - sum.Write([]byte(runtime.Version())) - sum.Write([]byte{'\n'}) - sum.Write([]byte(importPath)) - sum.Write([]byte{'\n'}) - sum.Write([]byte(packageName)) - sum.Write([]byte{'\n'}) - for _, name := range files { - info, err := os.Stat(name) - if err != nil { - return "", err - } - sum.Write([]byte(name)) - sum.Write([]byte{'\n'}) - sum.Write([]byte(strconv.FormatInt(info.Size(), 10))) - 
sum.Write([]byte{'\n'}) - sum.Write([]byte(strconv.FormatInt(info.ModTime().UnixNano(), 10))) - sum.Write([]byte{'\n'}) - } - return hex.EncodeToString(sum.Sum(nil)), nil -} - -func splitEnv(kv string) (string, string, bool) { - for i := 0; i < len(kv); i++ { - if kv[i] == '=' { - return kv[:i], kv[i+1:], true - } - } - return "", "", false -} From bf14bb5e5ea652e0d06ee6c632613074806bdf81 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Mon, 16 Mar 2026 23:57:34 -0500 Subject: [PATCH 69/82] refactor: add import benchmark profile filter --- internal/wire/import_bench_test.go | 13 +++++++++++++ scripts/import-benchmarks.sh | 9 +++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index cd38190..ad2af50 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -21,6 +21,7 @@ const ( importBenchBreakdown = "WIRE_IMPORT_BENCH_BREAKDOWN" importBenchScenarios = "WIRE_IMPORT_BENCH_SCENARIOS" importBenchScenarioBD = "WIRE_IMPORT_BENCH_SCENARIO_BREAKDOWN" + importBenchProfile = "WIRE_IMPORT_BENCH_PROFILE" stockWireCommit = "9c25c9016f6825302537c4efdd5e897976f9c826" stockWireModulePath = "github.com/google/wire" currentWireModulePath = "github.com/goforj/wire" @@ -145,6 +146,18 @@ func TestPrintImportScenarioBenchmarkTable(t *testing.T) { {localPkgs: 10, depPkgs: 25, external: true, label: "external"}, {localPkgs: 10, depPkgs: 100, external: true, label: "external"}, } + if filter := os.Getenv(importBenchProfile); filter != "" { + filtered := make([]appBenchProfile, 0, len(profiles)) + for _, profile := range profiles { + if profile.label == filter { + filtered = append(filtered, profile) + } + } + if len(filtered) == 0 { + t.Fatalf("%s=%q did not match any benchmark profile", importBenchProfile, filter) + } + profiles = filtered + } rows := make([]importBenchScenarioRow, 0, len(profiles)*6) for _, profile := range profiles { shapeFixture := 
createAppShapeBenchFixture(t, profile.localPkgs, profile.depPkgs, profile.external, currentWireModulePath, repoRoot) diff --git a/scripts/import-benchmarks.sh b/scripts/import-benchmarks.sh index 2eb98c2..6c1ca47 100755 --- a/scripts/import-benchmarks.sh +++ b/scripts/import-benchmarks.sh @@ -12,12 +12,13 @@ usage() { cat <<'EOF' Usage: scripts/import-benchmarks.sh table - scripts/import-benchmarks.sh scenarios + scripts/import-benchmarks.sh scenarios [profile] scripts/import-benchmarks.sh breakdown Commands: table Print the 10/100/1000 import stock-vs-current benchmark table. scenarios Print the stock-vs-current change-type scenario table. + Optional profiles: local, local-high, external. breakdown Print a focused 1000-import cold/unchanged breakdown. EOF } @@ -27,7 +28,11 @@ case "${1:-}" in WIRE_IMPORT_BENCH_TABLE=1 go test ./internal/wire -run TestPrintImportScaleBenchmarkTable -count=1 -v ;; scenarios) - WIRE_IMPORT_BENCH_SCENARIOS=1 go test ./internal/wire -run TestPrintImportScenarioBenchmarkTable -count=1 -v + if [[ -n "${2:-}" ]]; then + WIRE_IMPORT_BENCH_SCENARIOS=1 WIRE_IMPORT_BENCH_PROFILE="${2}" go test ./internal/wire -run TestPrintImportScenarioBenchmarkTable -count=1 -v + else + WIRE_IMPORT_BENCH_SCENARIOS=1 go test ./internal/wire -run TestPrintImportScenarioBenchmarkTable -count=1 -v + fi ;; breakdown) WIRE_IMPORT_BENCH_BREAKDOWN=1 go test ./internal/wire -run TestPrintImportScaleBenchmarkBreakdown -count=1 -v From bf4a02deb2bb7c342a7d3ee02bd2759e18cb32c9 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 00:05:28 -0500 Subject: [PATCH 70/82] refactor: trim redundant discovery cache metadata --- internal/loader/discovery_cache.go | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go index 9d7d932..d891d01 100644 --- a/internal/loader/discovery_cache.go +++ b/internal/loader/discovery_cache.go @@ -15,11 +15,6 @@ import ( type 
discoveryCacheEntry struct { Version int - WD string - Tags string - Patterns []string - NeedDeps bool - Workspace string Meta map[string]*packageMeta Global []discoveryFileMeta LocalPkgs []discoveryLocalPackage @@ -49,6 +44,8 @@ type discoveryFileFingerprint struct { Hash string } +const discoveryCacheVersion = 3 + func readDiscoveryCache(req goListRequest) (map[string]*packageMeta, bool) { entry, err := loadDiscoveryCacheEntry(req) if err != nil || entry == nil { @@ -63,13 +60,8 @@ func readDiscoveryCache(req goListRequest) (map[string]*packageMeta, bool) { func buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) (*discoveryCacheEntry, error) { workspace := detectModuleRoot(req.WD) entry := &discoveryCacheEntry{ - Version: 3, - WD: canonicalLoaderPath(req.WD), - Tags: req.Tags, - Patterns: append([]string(nil), req.Patterns...), - NeedDeps: req.NeedDeps, - Workspace: workspace, - Meta: clonePackageMetaMap(meta), + Version: discoveryCacheVersion, + Meta: clonePackageMetaMap(meta), } global := []string{ filepath.Join(workspace, "go.mod"), @@ -108,7 +100,7 @@ func buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) ( } func validateDiscoveryCacheEntry(entry *discoveryCacheEntry) bool { - if entry == nil || entry.Version != 3 { + if entry == nil || entry.Version != discoveryCacheVersion { return false } for _, fm := range entry.Global { @@ -142,7 +134,7 @@ func discoveryCachePath(req goListRequest) (string, error) { NeedDeps bool Go string }{ - Version: 3, + Version: discoveryCacheVersion, WD: canonicalLoaderPath(req.WD), Tags: req.Tags, Patterns: append([]string(nil), req.Patterns...), From 788798d1b127311871a4b467a99b8c4db16e6f90 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 00:06:38 -0500 Subject: [PATCH 71/82] refactor: remove redundant discovery cache cloning --- internal/loader/discovery_cache.go | 46 ++---------------------------- 1 file changed, 2 insertions(+), 44 deletions(-) diff --git 
a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go index d891d01..52a67c9 100644 --- a/internal/loader/discovery_cache.go +++ b/internal/loader/discovery_cache.go @@ -54,14 +54,14 @@ func readDiscoveryCache(req goListRequest) (map[string]*packageMeta, bool) { if !validateDiscoveryCacheEntry(entry) { return nil, false } - return clonePackageMetaMap(entry.Meta), true + return entry.Meta, true } func buildDiscoveryCacheEntry(req goListRequest, meta map[string]*packageMeta) (*discoveryCacheEntry, error) { workspace := detectModuleRoot(req.WD) entry := &discoveryCacheEntry{ Version: discoveryCacheVersion, - Meta: clonePackageMetaMap(meta), + Meta: meta, } global := []string{ filepath.Join(workspace, "go.mod"), @@ -285,45 +285,3 @@ func hashGob(v interface{}) (string, error) { sum := sha256.Sum256(buf.Bytes()) return hex.EncodeToString(sum[:]), nil } - -func clonePackageMetaMap(in map[string]*packageMeta) map[string]*packageMeta { - if len(in) == 0 { - return nil - } - out := make(map[string]*packageMeta, len(in)) - for k, v := range in { - if v == nil { - continue - } - cp := *v - cp.GoFiles = append([]string(nil), v.GoFiles...) - cp.CompiledGoFiles = append([]string(nil), v.CompiledGoFiles...) - cp.Imports = append([]string(nil), v.Imports...) 
- if v.ImportMap != nil { - cp.ImportMap = make(map[string]string, len(v.ImportMap)) - for mk, mv := range v.ImportMap { - cp.ImportMap[mk] = mv - } - } - if v.Module != nil { - cp.Module = cloneGoListModule(v.Module) - } - if v.Error != nil { - errCopy := *v.Error - cp.Error = &errCopy - } - out[k] = &cp - } - return out -} - -func cloneGoListModule(in *goListModule) *goListModule { - if in == nil { - return nil - } - cp := *in - if in.Replace != nil { - cp.Replace = cloneGoListModule(in.Replace) - } - return &cp -} From 0921fc2e6eb00b0b4d138d7f7608f3eceec08fbb Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 00:11:07 -0500 Subject: [PATCH 72/82] refactor: split external benchmark profiles --- internal/wire/import_bench_test.go | 4 ++-- scripts/import-benchmarks.sh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index ad2af50..d3f8ed7 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -143,8 +143,8 @@ func TestPrintImportScenarioBenchmarkTable(t *testing.T) { profiles := []appBenchProfile{ {localPkgs: 10, depPkgs: 25, label: "local"}, {localPkgs: 10, depPkgs: 1000, label: "local-high"}, - {localPkgs: 10, depPkgs: 25, external: true, label: "external"}, - {localPkgs: 10, depPkgs: 100, external: true, label: "external"}, + {localPkgs: 10, depPkgs: 25, external: true, label: "external-low"}, + {localPkgs: 10, depPkgs: 100, external: true, label: "external-high"}, } if filter := os.Getenv(importBenchProfile); filter != "" { filtered := make([]appBenchProfile, 0, len(profiles)) diff --git a/scripts/import-benchmarks.sh b/scripts/import-benchmarks.sh index 6c1ca47..232ccd9 100755 --- a/scripts/import-benchmarks.sh +++ b/scripts/import-benchmarks.sh @@ -18,7 +18,7 @@ Usage: Commands: table Print the 10/100/1000 import stock-vs-current benchmark table. scenarios Print the stock-vs-current change-type scenario table. 
- Optional profiles: local, local-high, external. + Optional profiles: local, local-high, external-low, external-high. breakdown Print a focused 1000-import cold/unchanged breakdown. EOF } From e52201b03a4817e85fbe39cfb7f392f6a5471c4c Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 00:18:57 -0500 Subject: [PATCH 73/82] refactor: share custom typed load pipeline --- internal/loader/custom.go | 44 +++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 6fa586b..76f86ba 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -235,19 +235,10 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz } } discoveryDuration := time.Since(discoveryStart) - if len(meta) == 0 { - return nil, unsupportedError{reason: "empty go list result"} - } - fset := req.Fset - if fset == nil { - fset = token.NewFileSet() - } - l := newCustomTypedGraphLoader(ctx, req.WD, req.Env, fset, meta, map[string]struct{}{req.Package: {}}, req.ParseFile, discoveryDuration) - roots, err := loadCustomRootPackages(l, []string{req.Package}) + roots, err := loadCustomPackagesFromMeta(ctx, req.WD, req.Env, req.Fset, meta, map[string]struct{}{req.Package: {}}, []string{req.Package}, req.ParseFile, discoveryDuration, "lazy") if err != nil { return nil, err } - logTypedLoadStats(ctx, "lazy", l.stats) return &LazyLoadResult{ Packages: roots, Backend: ModeCustom, @@ -267,33 +258,40 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo return nil, err } discoveryDuration := time.Since(discoveryStart) - if len(meta) == 0 { - return nil, unsupportedError{reason: "empty go list result"} - } - fset := req.Fset - if fset == nil { - fset = token.NewFileSet() - } rootPaths := nonDepRootImportPaths(meta) targets := make(map[string]struct{}, len(rootPaths)) for _, path := range rootPaths { targets[path] = struct{}{} } - if 
len(targets) == 0 { - return nil, unsupportedError{reason: "no root packages from metadata"} - } - l := newCustomTypedGraphLoader(ctx, req.WD, req.Env, fset, meta, targets, req.ParseFile, discoveryDuration) - roots, err := loadCustomRootPackages(l, rootPaths) + roots, err := loadCustomPackagesFromMeta(ctx, req.WD, req.Env, req.Fset, meta, targets, rootPaths, req.ParseFile, discoveryDuration, "typed") if err != nil { return nil, err } - logTypedLoadStats(ctx, "typed", l.stats) return &PackageLoadResult{ Packages: roots, Backend: ModeCustom, }, nil } +func loadCustomPackagesFromMeta(ctx context.Context, wd string, env []string, fset *token.FileSet, meta map[string]*packageMeta, targets map[string]struct{}, rootPaths []string, parseFile ParseFileFunc, discoveryDuration time.Duration, mode string) ([]*packages.Package, error) { + if len(meta) == 0 { + return nil, unsupportedError{reason: "empty go list result"} + } + if len(rootPaths) == 0 { + return nil, unsupportedError{reason: "no root packages from metadata"} + } + if fset == nil { + fset = token.NewFileSet() + } + l := newCustomTypedGraphLoader(ctx, wd, env, fset, meta, targets, parseFile, discoveryDuration) + roots, err := loadCustomRootPackages(l, rootPaths) + if err != nil { + return nil, err + } + logTypedLoadStats(ctx, mode, l.stats) + return roots, nil +} + func loadCustomRootPackages(l *customTypedGraphLoader, paths []string) ([]*packages.Package, error) { prefetchStart := time.Now() l.prefetchArtifacts() From 2dc4ac4e5fd4b75b70a70a9e5ade047278440c83 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 00:21:41 -0500 Subject: [PATCH 74/82] refactor: centralize custom metadata loading --- internal/loader/custom.go | 61 +++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 76f86ba..fce84f4 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -173,8 +173,7 @@ func 
validateTouchedPackagesCustom(ctx context.Context, req TouchedValidationReq } func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadResult, error) { - discoveryStart := time.Now() - meta, err := runGoList(ctx, goListRequest{ + meta, discoveryDuration, err := loadCustomMeta(ctx, goListRequest{ WD: req.WD, Env: req.Env, Tags: req.Tags, @@ -184,10 +183,7 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes if err != nil { return nil, err } - logTiming(ctx, "loader.custom.root.discovery", discoveryStart) - if len(meta) == 0 { - return nil, unsupportedError{reason: "empty go list result"} - } + logDuration(ctx, "loader.custom.root.discovery", discoveryDuration) pkgs := packageStubGraphFromMeta(nil, meta) rootPaths := nonDepRootImportPaths(meta) roots := make([]*packages.Package, 0, len(rootPaths)) @@ -219,22 +215,10 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz meta map[string]*packageMeta err error ) - discoveryStart := time.Now() - if req.Discovery != nil && len(req.Discovery.meta) > 0 { - meta = req.Discovery.meta - } else { - meta, err = runGoList(ctx, goListRequest{ - WD: req.WD, - Env: req.Env, - Tags: req.Tags, - Patterns: []string{req.Package}, - NeedDeps: true, - }) - if err != nil { - return nil, err - } + meta, discoveryDuration, err := loadCustomLazyMeta(ctx, req) + if err != nil { + return nil, err } - discoveryDuration := time.Since(discoveryStart) roots, err := loadCustomPackagesFromMeta(ctx, req.WD, req.Env, req.Fset, meta, map[string]struct{}{req.Package: {}}, []string{req.Package}, req.ParseFile, discoveryDuration, "lazy") if err != nil { return nil, err @@ -246,8 +230,7 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz } func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLoadResult, error) { - discoveryStart := time.Now() - meta, err := runGoList(ctx, goListRequest{ + meta, discoveryDuration, err := 
loadCustomMeta(ctx, goListRequest{ WD: req.WD, Env: req.Env, Tags: req.Tags, @@ -257,7 +240,6 @@ func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLo if err != nil { return nil, err } - discoveryDuration := time.Since(discoveryStart) rootPaths := nonDepRootImportPaths(meta) targets := make(map[string]struct{}, len(rootPaths)) for _, path := range rootPaths { @@ -292,6 +274,32 @@ func loadCustomPackagesFromMeta(ctx context.Context, wd string, env []string, fs return roots, nil } +func loadCustomMeta(ctx context.Context, req goListRequest) (map[string]*packageMeta, time.Duration, error) { + start := time.Now() + meta, err := runGoList(ctx, req) + duration := time.Since(start) + if err != nil { + return nil, duration, err + } + if len(meta) == 0 { + return nil, duration, unsupportedError{reason: "empty go list result"} + } + return meta, duration, nil +} + +func loadCustomLazyMeta(ctx context.Context, req LazyLoadRequest) (map[string]*packageMeta, time.Duration, error) { + if req.Discovery != nil && len(req.Discovery.meta) > 0 { + return req.Discovery.meta, 0, nil + } + return loadCustomMeta(ctx, goListRequest{ + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: []string{req.Package}, + NeedDeps: true, + }) +} + func loadCustomRootPackages(l *customTypedGraphLoader, paths []string) ([]*packages.Package, error) { prefetchStart := time.Now() l.prefetchArtifacts() @@ -994,7 +1002,7 @@ func importName(spec *ast.ImportSpec) string { } func discoverTouchedMetadata(ctx context.Context, req TouchedValidationRequest) (map[string]*packageMeta, error) { - metas, err := runGoList(ctx, goListRequest{ + metas, _, err := loadCustomMeta(ctx, goListRequest{ WD: req.WD, Env: req.Env, Tags: req.Tags, @@ -1004,9 +1012,6 @@ func discoverTouchedMetadata(ctx context.Context, req TouchedValidationRequest) if err != nil { return nil, err } - if len(metas) == 0 { - return nil, unsupportedError{reason: "empty go list result"} - } for _, touched := range 
req.Touched { if _, ok := metas[touched]; !ok { return nil, unsupportedError{reason: "missing touched package in metadata"} From 6f0966483067e91a987ba1f5f8b3df786fbdf8c3 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 00:25:39 -0500 Subject: [PATCH 75/82] refactor: share loader fallback reason policy --- internal/loader/fallback.go | 47 ++++++++++--------------------------- 1 file changed, 13 insertions(+), 34 deletions(-) diff --git a/internal/loader/fallback.go b/internal/loader/fallback.go index 513694c..860bd50 100644 --- a/internal/loader/fallback.go +++ b/internal/loader/fallback.go @@ -24,6 +24,15 @@ import ( type defaultLoader struct{} +func fallbackReasonDetail(mode Mode, detail string) (FallbackReason, string) { + switch mode { + case ModeFallback: + return FallbackReasonForcedFallback, "" + default: + return FallbackReasonCustomUnsupported, detail + } +} + func (defaultLoader) LoadPackages(ctx context.Context, req PackageLoadRequest) (*PackageLoadResult, error) { var unsupported unsupportedError if req.LoaderMode != ModeFallback { @@ -38,15 +47,7 @@ func (defaultLoader) LoadPackages(ctx context.Context, req PackageLoadRequest) ( result := &PackageLoadResult{ Backend: ModeFallback, } - switch req.LoaderMode { - case ModeFallback: - result.FallbackReason = FallbackReasonForcedFallback - default: - result.FallbackReason = FallbackReasonCustomUnsupported - if unsupported.reason != "" { - result.FallbackDetail = unsupported.reason - } - } + result.FallbackReason, result.FallbackDetail = fallbackReasonDetail(req.LoaderMode, unsupported.reason) cfg := &packages.Config{ Context: ctx, Mode: req.Mode, @@ -90,15 +91,7 @@ func (defaultLoader) LoadRootGraph(ctx context.Context, req RootLoadRequest) (*R result := &RootLoadResult{ Backend: ModeFallback, } - switch req.Mode { - case ModeFallback: - result.FallbackReason = FallbackReasonForcedFallback - default: - result.FallbackReason = FallbackReasonCustomUnsupported - if unsupported.reason != "" { - 
result.FallbackDetail = unsupported.reason - } - } + result.FallbackReason, result.FallbackDetail = fallbackReasonDetail(req.Mode, unsupported.reason) cfg := &packages.Config{ Context: ctx, Mode: packages.NeedName | packages.NeedFiles | packages.NeedCompiledGoFiles | packages.NeedImports, @@ -142,15 +135,7 @@ func (defaultLoader) LoadTypedPackageGraph(ctx context.Context, req LazyLoadRequ result := &LazyLoadResult{ Backend: ModeFallback, } - switch req.LoaderMode { - case ModeFallback: - result.FallbackReason = FallbackReasonForcedFallback - default: - result.FallbackReason = FallbackReasonCustomUnsupported - if unsupported.reason != "" { - result.FallbackDetail = unsupported.reason - } - } + result.FallbackReason, result.FallbackDetail = fallbackReasonDetail(req.LoaderMode, unsupported.reason) cfg := &packages.Config{ Context: ctx, Mode: req.Mode, @@ -194,13 +179,7 @@ func validateTouchedPackagesFallback(ctx context.Context, req TouchedValidationR result := &TouchedValidationResult{ Backend: ModeFallback, } - switch req.Mode { - case ModeFallback: - result.FallbackReason = FallbackReasonForcedFallback - default: - result.FallbackReason = FallbackReasonCustomUnsupported - result.FallbackDetail = detail - } + result.FallbackReason, result.FallbackDetail = fallbackReasonDetail(req.Mode, detail) if len(req.Touched) == 0 { return result, nil } From d003c9f88f447f922979a08406d6d67b2a1b31ed Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 00:45:14 -0500 Subject: [PATCH 76/82] refactor: add targeted local profile benchmark --- internal/wire/import_bench_test.go | 80 ++++++++++++++++++++++++------ 1 file changed, 64 insertions(+), 16 deletions(-) diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index d3f8ed7..24c58da 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -266,6 +266,25 @@ func TestPrintImportScenarioBenchmarkBreakdown(t *testing.T) { 
printScenarioTimingLines(currentOutput) } +func BenchmarkCurrentWireLocalProfile(b *testing.B) { + repoRoot, err := importBenchRepoRoot() + if err != nil { + b.Fatal(err) + } + currentBin := buildWireBinary(b, repoRoot, "current-wire") + const ( + features = 10 + depPkgs = 25 + external = false + ) + + for _, variant := range []string{"unchanged", "body", "shape", "import", "known-toggle"} { + b.Run(variant, func(b *testing.B) { + benchmarkCurrentWireAppScenario(b, currentBin, repoRoot, features, depPkgs, external, variant) + }) + } +} + func runAppColdTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string, trials int) []time.Duration { t.Helper() durations := make([]time.Duration, 0, trials) @@ -278,6 +297,35 @@ func runAppColdTrials(t *testing.T, bin string, features, depPkgs int, external return durations } +func benchmarkCurrentWireAppScenario(b *testing.B, bin, repoRoot string, features, depPkgs int, external bool, variant string) { + b.Helper() + pkgDir := createAppShapeBenchFixture(b, features, depPkgs, external, currentWireModulePath, repoRoot) + caches := newBenchCaches(b) + root := filepath.Dir(pkgDir) + prewarmGoBenchCache(b, pkgDir, caches) + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + resetAppShapeBenchFixture(b, pkgDir, features) + switch variant { + case "body", "shape", "import": + _ = runWireBenchCommand(b, bin, pkgDir, caches) + writeAppShapeControllerFile(b, root, 0, variant) + case "known-toggle": + _ = runWireBenchCommand(b, bin, pkgDir, caches) + writeAppShapeControllerFile(b, root, 0, "shape") + _ = runWireBenchCommand(b, bin, pkgDir, caches) + writeAppShapeControllerFile(b, root, 0, "base") + case "unchanged": + _ = runWireBenchCommand(b, bin, pkgDir, caches) + default: + b.Fatalf("unknown benchmark variant %q", variant) + } + b.StartTimer() + _ = runWireBenchCommand(b, bin, pkgDir, caches) + } +} + func runAppWarmTrials(t *testing.T, bin string, features, depPkgs int, 
external bool, wireModulePath, wireReplaceDir string, trials int) []time.Duration { t.Helper() durations := make([]time.Duration, 0, trials) @@ -326,7 +374,7 @@ func runAppKnownToggleTrials(t *testing.T, bin string, features, depPkgs int, ex return durations } -func buildWireBinary(t *testing.T, dir, name string) string { +func buildWireBinary(t testing.TB, dir, name string) string { t.Helper() if runtime.GOOS == "windows" && filepath.Ext(name) != ".exe" { name += ".exe" @@ -342,7 +390,7 @@ func buildWireBinary(t *testing.T, dir, name string) string { return out } -func newBenchCaches(t *testing.T) benchCaches { +func newBenchCaches(t testing.TB) benchCaches { t.Helper() return benchCaches{ home: t.TempDir(), @@ -350,7 +398,7 @@ func newBenchCaches(t *testing.T) benchCaches { } } -func extractStockWire(t *testing.T, repoRoot, commit string) string { +func extractStockWire(t testing.TB, repoRoot, commit string) string { t.Helper() tmp := t.TempDir() cmd := exec.Command("git", "archive", "--format=tar", commit) @@ -400,7 +448,7 @@ func extractStockWire(t *testing.T, repoRoot, commit string) string { return tmp } -func createImportBenchFixture(t *testing.T, imports int, wireModulePath, wireReplaceDir string) string { +func createImportBenchFixture(t testing.TB, imports int, wireModulePath, wireReplaceDir string) string { t.Helper() root := t.TempDir() if err := os.WriteFile(filepath.Join(root, "go.mod"), []byte(importBenchGoMod(wireModulePath, wireReplaceDir)), 0o644); err != nil { @@ -424,7 +472,7 @@ func createImportBenchFixture(t *testing.T, imports int, wireModulePath, wireRep return filepath.Join(root, "app") } -func createAppShapeBenchFixture(t *testing.T, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string) string { +func createAppShapeBenchFixture(t testing.TB, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string) string { t.Helper() root := t.TempDir() modulePath := "example.com/appbench" @@ -455,7 +503,7 @@ func 
createAppShapeBenchFixture(t *testing.T, features, depPkgs int, external bo return filepath.Join(root, "wire") } -func writeAppShapeFile(t *testing.T, path, content string) { +func writeAppShapeFile(t testing.TB, path, content string) { t.Helper() if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { t.Fatal(err) @@ -465,7 +513,7 @@ func writeAppShapeFile(t *testing.T, path, content string) { } } -func writeAppShapeControllerFile(t *testing.T, root string, index int, variant string) { +func writeAppShapeControllerFile(t testing.TB, root string, index int, variant string) { t.Helper() path := filepath.Join(root, "internal", fmt.Sprintf("feature%04d", index), "controller.go") if err := os.WriteFile(path, []byte(appShapeControllerFile("example.com/appbench", index, variant)), 0o644); err != nil { @@ -473,7 +521,7 @@ func writeAppShapeControllerFile(t *testing.T, root string, index int, variant s } } -func seedAppShapeExternalGoSum(t *testing.T, root string) { +func seedAppShapeExternalGoSum(t testing.TB, root string) { t.Helper() const source = "/private/tmp/test/go.sum" data, err := os.ReadFile(source) @@ -485,7 +533,7 @@ func seedAppShapeExternalGoSum(t *testing.T, root string) { } } -func resetAppShapeBenchFixture(t *testing.T, pkgDir string, features int) { +func resetAppShapeBenchFixture(t testing.TB, pkgDir string, features int) { t.Helper() root := filepath.Dir(pkgDir) for i := 0; i < features; i++ { @@ -1019,13 +1067,13 @@ func runKnownImportToggleTrials(t *testing.T, bin string, imports int, wireModul return durations } -func runWireBenchCommand(t *testing.T, bin, pkgDir string, caches benchCaches) time.Duration { +func runWireBenchCommand(t testing.TB, bin, pkgDir string, caches benchCaches) time.Duration { t.Helper() d, _ := runWireBenchCommandOutput(t, bin, pkgDir, caches) return d } -func runWireBenchCommandOutput(t *testing.T, bin, pkgDir string, caches benchCaches, extraArgs ...string) (time.Duration, string) { +func runWireBenchCommandOutput(t 
testing.TB, bin, pkgDir string, caches benchCaches, extraArgs ...string) (time.Duration, string) { t.Helper() args := []string{"gen"} args = append(args, extraArgs...) @@ -1042,7 +1090,7 @@ func runWireBenchCommandOutput(t *testing.T, bin, pkgDir string, caches benchCac return time.Since(start), stderr.String() } -func prewarmGoBenchCache(t *testing.T, pkgDir string, caches benchCaches) { +func prewarmGoBenchCache(t testing.TB, pkgDir string, caches benchCaches) { t.Helper() prepareBenchModule(t, pkgDir, caches) cmd := exec.Command("go", "list", "-deps", "./...") @@ -1054,7 +1102,7 @@ func prewarmGoBenchCache(t *testing.T, pkgDir string, caches benchCaches) { } } -func goListGraphCounts(t *testing.T, pkgDir, modulePath string, caches benchCaches) benchGraphCounts { +func goListGraphCounts(t testing.TB, pkgDir, modulePath string, caches benchCaches) benchGraphCounts { t.Helper() prepareBenchModule(t, pkgDir, caches) cmd := exec.Command("go", "list", "-deps", "-json", "./...") @@ -1097,7 +1145,7 @@ func goListGraphCounts(t *testing.T, pkgDir, modulePath string, caches benchCach return counts } -func prepareBenchModule(t *testing.T, pkgDir string, caches benchCaches) { +func prepareBenchModule(t testing.TB, pkgDir string, caches benchCaches) { t.Helper() marker := filepath.Join(filepath.Dir(pkgDir), ".bench-module-ready") if _, err := os.Stat(marker); err == nil { @@ -1216,7 +1264,7 @@ func importBenchDepFile(i int, variant string) string { } } -func writeImportBenchWireFile(t *testing.T, root string, imports int, wireModulePath string) { +func writeImportBenchWireFile(t testing.TB, root string, imports int, wireModulePath string) { t.Helper() path := filepath.Join(root, "app", "wire.go") if err := os.WriteFile(path, []byte(importBenchWireFile(imports, wireModulePath)), 0o644); err != nil { @@ -1224,7 +1272,7 @@ func writeImportBenchWireFile(t *testing.T, root string, imports int, wireModule } } -func writeImportBenchDepFile(t *testing.T, root string, index int, 
variant string) { +func writeImportBenchDepFile(t testing.TB, root string, index int, variant string) { t.Helper() path := filepath.Join(root, fmt.Sprintf("dep%04d", index), "dep.go") if err := os.WriteFile(path, []byte(importBenchDepFile(index, variant)), 0o644); err != nil { From a1125656097b1724625732ec13a268209acb337d Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 00:49:59 -0500 Subject: [PATCH 77/82] refactor: add one-shot import profile harness --- internal/wire/import_bench_test.go | 96 ++++++++++++++++++++++++++---- 1 file changed, 84 insertions(+), 12 deletions(-) diff --git a/internal/wire/import_bench_test.go b/internal/wire/import_bench_test.go index 24c58da..39dd862 100644 --- a/internal/wire/import_bench_test.go +++ b/internal/wire/import_bench_test.go @@ -22,6 +22,9 @@ const ( importBenchScenarios = "WIRE_IMPORT_BENCH_SCENARIOS" importBenchScenarioBD = "WIRE_IMPORT_BENCH_SCENARIO_BREAKDOWN" importBenchProfile = "WIRE_IMPORT_BENCH_PROFILE" + importBenchProfileRun = "WIRE_IMPORT_BENCH_PROFILE_RUN" + importBenchVariant = "WIRE_IMPORT_BENCH_VARIANT" + importBenchCPUProfile = "WIRE_IMPORT_BENCH_CPU_PROFILE" stockWireCommit = "9c25c9016f6825302537c4efdd5e897976f9c826" stockWireModulePath = "github.com/google/wire" currentWireModulePath = "github.com/goforj/wire" @@ -134,18 +137,7 @@ func TestPrintImportScenarioBenchmarkTable(t *testing.T) { stockDir := extractStockWire(t, repoRoot, stockWireCommit) stockBin := buildWireBinary(t, stockDir, "stock-wire") - type appBenchProfile struct { - localPkgs int - depPkgs int - external bool - label string - } - profiles := []appBenchProfile{ - {localPkgs: 10, depPkgs: 25, label: "local"}, - {localPkgs: 10, depPkgs: 1000, label: "local-high"}, - {localPkgs: 10, depPkgs: 25, external: true, label: "external-low"}, - {localPkgs: 10, depPkgs: 100, external: true, label: "external-high"}, - } + profiles := importBenchAppProfiles() if filter := os.Getenv(importBenchProfile); filter != "" { filtered := 
make([]appBenchProfile, 0, len(profiles)) for _, profile := range profiles { @@ -266,6 +258,61 @@ func TestPrintImportScenarioBenchmarkBreakdown(t *testing.T) { printScenarioTimingLines(currentOutput) } +func TestProfileCurrentWireScenarioRun(t *testing.T) { + if os.Getenv(importBenchProfileRun) != "1" { + t.Skipf("%s not set", importBenchProfileRun) + } + profile := os.Getenv(importBenchProfile) + variant := os.Getenv(importBenchVariant) + cpuProfile := os.Getenv(importBenchCPUProfile) + if profile == "" { + t.Fatalf("%s must be set", importBenchProfile) + } + if variant == "" { + t.Fatalf("%s must be set", importBenchVariant) + } + if cpuProfile == "" { + t.Fatalf("%s must be set", importBenchCPUProfile) + } + repoRoot, err := importBenchRepoRoot() + if err != nil { + t.Fatal(err) + } + currentBin := buildWireBinary(t, repoRoot, "current-wire") + profileCfg, err := importBenchAppProfile(profile) + if err != nil { + t.Fatal(err) + } + pkgDir := createAppShapeBenchFixture(t, profileCfg.localPkgs, profileCfg.depPkgs, profileCfg.external, currentWireModulePath, repoRoot) + caches := newBenchCaches(t) + prewarmGoBenchCache(t, pkgDir, caches) + root := filepath.Dir(pkgDir) + + switch variant { + case "unchanged": + _ = runWireBenchCommand(t, currentBin, pkgDir, caches) + case "body", "shape", "import": + resetAppShapeBenchFixture(t, pkgDir, profileCfg.localPkgs) + _ = runWireBenchCommand(t, currentBin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, variant) + case "known-toggle": + resetAppShapeBenchFixture(t, pkgDir, profileCfg.localPkgs) + _ = runWireBenchCommand(t, currentBin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, "shape") + _ = runWireBenchCommand(t, currentBin, pkgDir, caches) + writeAppShapeControllerFile(t, root, 0, "base") + default: + t.Fatalf("unknown %s %q", importBenchVariant, variant) + } + + dur, output := runWireBenchCommandOutput(t, currentBin, pkgDir, caches, "-cpuprofile="+cpuProfile, "-timings") + fmt.Printf("profile: 
%s\n", profile) + fmt.Printf("variant: %s\n", variant) + fmt.Printf("duration: %s\n", formatMs(dur)) + fmt.Printf("cpuprofile: %s\n", cpuProfile) + printScenarioTimingLines(output) +} + func BenchmarkCurrentWireLocalProfile(b *testing.B) { repoRoot, err := importBenchRepoRoot() if err != nil { @@ -285,6 +332,31 @@ func BenchmarkCurrentWireLocalProfile(b *testing.B) { } } +type appBenchProfile struct { + localPkgs int + depPkgs int + external bool + label string +} + +func importBenchAppProfiles() []appBenchProfile { + return []appBenchProfile{ + {localPkgs: 10, depPkgs: 25, label: "local"}, + {localPkgs: 10, depPkgs: 1000, label: "local-high"}, + {localPkgs: 10, depPkgs: 25, external: true, label: "external-low"}, + {localPkgs: 10, depPkgs: 100, external: true, label: "external-high"}, + } +} + +func importBenchAppProfile(label string) (appBenchProfile, error) { + for _, profile := range importBenchAppProfiles() { + if profile.label == label { + return profile, nil + } + } + return appBenchProfile{}, fmt.Errorf("%s=%q did not match any benchmark profile", importBenchProfile, label) +} + func runAppColdTrials(t *testing.T, bin string, features, depPkgs int, external bool, wireModulePath, wireReplaceDir string, trials int) []time.Duration { t.Helper() durations := make([]time.Duration, 0, trials) From ee3ffc9890c2db874dce6e0c116d7f328c075247 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 01:56:00 -0500 Subject: [PATCH 78/82] perf: reuse root discovery for generate loads --- internal/loader/custom.go | 24 +++++++++++++++++------- internal/loader/discovery.go | 6 +++++- internal/loader/discovery_cache.go | 14 ++++++++------ internal/loader/loader.go | 1 + internal/wire/output_cache.go | 16 ++++++++-------- internal/wire/parse.go | 5 +++-- internal/wire/wire.go | 4 ++-- 7 files changed, 44 insertions(+), 26 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index fce84f4..112fe85 100644 --- a/internal/loader/custom.go +++ 
b/internal/loader/custom.go @@ -179,6 +179,7 @@ func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadRes Tags: req.Tags, Patterns: req.Patterns, NeedDeps: req.NeedDeps, + SkipCompiled: true, }) if err != nil { return nil, err @@ -230,13 +231,22 @@ func loadTypedPackageGraphCustom(ctx context.Context, req LazyLoadRequest) (*Laz } func loadPackagesCustom(ctx context.Context, req PackageLoadRequest) (*PackageLoadResult, error) { - meta, discoveryDuration, err := loadCustomMeta(ctx, goListRequest{ - WD: req.WD, - Env: req.Env, - Tags: req.Tags, - Patterns: req.Patterns, - NeedDeps: true, - }) + var ( + meta map[string]*packageMeta + discoveryDuration time.Duration + err error + ) + if req.Discovery != nil && len(req.Discovery.meta) > 0 { + meta = req.Discovery.meta + } else { + meta, discoveryDuration, err = loadCustomMeta(ctx, goListRequest{ + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: req.Patterns, + NeedDeps: true, + }) + } if err != nil { return nil, err } diff --git a/internal/loader/discovery.go b/internal/loader/discovery.go index 0e7e69c..e416e95 100644 --- a/internal/loader/discovery.go +++ b/internal/loader/discovery.go @@ -32,6 +32,7 @@ type goListRequest struct { Tags string Patterns []string NeedDeps bool + SkipCompiled bool } func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, error) { @@ -46,7 +47,10 @@ func runGoList(ctx context.Context, req goListRequest) (map[string]*packageMeta, return cached, nil } logDuration(ctx, "loader.discovery.cache_read.wall", time.Since(cacheReadStart)) - args := []string{"list", "-json", "-e", "-compiled", "-export"} + args := []string{"list", "-json", "-e", "-export"} + if !req.SkipCompiled { + args = append(args, "-compiled") + } if req.NeedDeps { args = append(args, "-deps") } diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go index 52a67c9..1a93bdf 100644 --- a/internal/loader/discovery_cache.go +++ 
b/internal/loader/discovery_cache.go @@ -132,14 +132,16 @@ func discoveryCachePath(req goListRequest) (string, error) { Tags string Patterns []string NeedDeps bool + SkipCompiled bool Go string }{ - Version: discoveryCacheVersion, - WD: canonicalLoaderPath(req.WD), - Tags: req.Tags, - Patterns: append([]string(nil), req.Patterns...), - NeedDeps: req.NeedDeps, - Go: runtime.Version(), + Version: discoveryCacheVersion, + WD: canonicalLoaderPath(req.WD), + Tags: req.Tags, + Patterns: append([]string(nil), req.Patterns...), + NeedDeps: req.NeedDeps, + SkipCompiled: req.SkipCompiled, + Go: runtime.Version(), } key, err := hashGob(sumReq) if err != nil { diff --git a/internal/loader/loader.go b/internal/loader/loader.go index e26747b..a507758 100644 --- a/internal/loader/loader.go +++ b/internal/loader/loader.go @@ -93,6 +93,7 @@ type PackageLoadRequest struct { LoaderMode Mode Fset *token.FileSet ParseFile ParseFileFunc + Discovery *DiscoverySnapshot } type PackageLoadResult struct { diff --git a/internal/wire/output_cache.go b/internal/wire/output_cache.go index b95a514..bd2bc8b 100644 --- a/internal/wire/output_cache.go +++ b/internal/wire/output_cache.go @@ -32,10 +32,10 @@ type outputCacheCandidate struct { outputPath string } -func prepareGenerateOutputCache(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) (map[string]outputCacheCandidate, []GenerateResult, bool) { +func prepareGenerateOutputCache(ctx context.Context, wd string, env []string, patterns []string, opts *GenerateOptions) (map[string]outputCacheCandidate, []GenerateResult, *loader.DiscoverySnapshot, bool) { if !outputCacheEnabled(ctx, wd, env) { debugf(ctx, "generate.output_cache=disabled") - return nil, nil, false + return nil, nil, nil, false } rootResult, err := loader.New().LoadRootGraph(withLoaderTiming(ctx), loader.RootLoadRequest{ WD: wd, @@ -51,7 +51,7 @@ func prepareGenerateOutputCache(ctx context.Context, wd string, env []string, pa } else { debugf(ctx, 
"generate.output_cache=no_roots") } - return nil, nil, false + return nil, nil, nil, false } candidates := make(map[string]outputCacheCandidate, len(rootResult.Packages)) results := make([]GenerateResult, 0, len(rootResult.Packages)) @@ -59,17 +59,17 @@ func prepareGenerateOutputCache(ctx context.Context, wd string, env []string, pa outDir, err := detectOutputDir(pkg.GoFiles) if err != nil { debugf(ctx, "generate.output_cache=bad_output_dir") - return candidates, nil, false + return candidates, nil, rootResult.Discovery, false } key, err := outputCacheKey(wd, opts, pkg) if err != nil { debugf(ctx, "generate.output_cache=key_error") - return candidates, nil, false + return candidates, nil, rootResult.Discovery, false } path, err := outputCachePath(env, key) if err != nil { debugf(ctx, "generate.output_cache=path_error") - return candidates, nil, false + return candidates, nil, rootResult.Discovery, false } candidates[pkg.PkgPath] = outputCacheCandidate{ path: path, @@ -78,7 +78,7 @@ func prepareGenerateOutputCache(ctx context.Context, wd string, env []string, pa entry, ok := readOutputCache(path) if !ok { debugf(ctx, "generate.output_cache=miss") - return candidates, nil, false + return candidates, nil, rootResult.Discovery, false } results = append(results, GenerateResult{ PkgPath: pkg.PkgPath, @@ -87,7 +87,7 @@ func prepareGenerateOutputCache(ctx context.Context, wd string, env []string, pa }) } debugf(ctx, "generate.output_cache=hit") - return candidates, results, len(results) == len(rootResult.Packages) + return candidates, results, rootResult.Discovery, len(results) == len(rootResult.Packages) } func writeGenerateOutputCache(candidates map[string]outputCacheCandidate, generated []GenerateResult) { diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 4350baa..2e9c428 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -254,7 +254,7 @@ type Field struct { // takes precedence. 
func Load(ctx context.Context, wd string, env []string, tags string, patterns []string) (*Info, []error) { loadStart := time.Now() - pkgs, errs := load(ctx, wd, env, tags, patterns) + pkgs, errs := load(ctx, wd, env, tags, patterns, nil) logTiming(ctx, "load.packages", loadStart) if len(errs) > 0 { return nil, errs @@ -361,7 +361,7 @@ func Load(ctx context.Context, wd string, env []string, tags string, patterns [] // env is nil or empty, it is interpreted as an empty set of variables. // In case of duplicate environment variables, the last one in the list // takes precedence. -func load(ctx context.Context, wd string, env []string, tags string, patterns []string) ([]*packages.Package, []error) { +func load(ctx context.Context, wd string, env []string, tags string, patterns []string, discovery *loader.DiscoverySnapshot) ([]*packages.Package, []error) { fset := token.NewFileSet() loaderMode := effectiveLoaderMode(ctx, wd, env) parseStats := &parseFileStats{} @@ -374,6 +374,7 @@ func load(ctx context.Context, wd string, env []string, tags string, patterns [] Mode: packages.LoadAllSyntax, LoaderMode: loaderMode, Fset: fset, + Discovery: discovery, ParseFile: func(fset *token.FileSet, filename string, src []byte) (*ast.File, error) { start := time.Now() file, err := parser.ParseFile(fset, filename, src, parser.ParseComments|parser.SkipObjectResolution) diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 1c44eba..9f5bb9e 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -102,12 +102,12 @@ func Generate(ctx context.Context, wd string, env []string, patterns []string, o if opts == nil { opts = &GenerateOptions{} } - cacheCandidates, cached, ok := prepareGenerateOutputCache(ctx, wd, env, patterns, opts) + cacheCandidates, cached, discovery, ok := prepareGenerateOutputCache(ctx, wd, env, patterns, opts) if ok { return cached, nil } loadStart := time.Now() - pkgs, errs := load(ctx, wd, env, opts.Tags, patterns) + pkgs, errs := load(ctx, wd, env, 
opts.Tags, patterns, discovery) logTiming(ctx, "generate.load", loadStart) if len(errs) > 0 { return nil, errs From cf528798871c8089e1d4dec084900115bfdcf035 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Tue, 17 Mar 2026 02:50:43 -0500 Subject: [PATCH 79/82] style: format loader discovery changes --- internal/loader/custom.go | 10 +++++----- internal/loader/discovery.go | 10 +++++----- internal/loader/discovery_cache.go | 12 ++++++------ 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/internal/loader/custom.go b/internal/loader/custom.go index 112fe85..b4532af 100644 --- a/internal/loader/custom.go +++ b/internal/loader/custom.go @@ -174,11 +174,11 @@ func validateTouchedPackagesCustom(ctx context.Context, req TouchedValidationReq func loadRootGraphCustom(ctx context.Context, req RootLoadRequest) (*RootLoadResult, error) { meta, discoveryDuration, err := loadCustomMeta(ctx, goListRequest{ - WD: req.WD, - Env: req.Env, - Tags: req.Tags, - Patterns: req.Patterns, - NeedDeps: req.NeedDeps, + WD: req.WD, + Env: req.Env, + Tags: req.Tags, + Patterns: req.Patterns, + NeedDeps: req.NeedDeps, SkipCompiled: true, }) if err != nil { diff --git a/internal/loader/discovery.go b/internal/loader/discovery.go index e416e95..bccfd93 100644 --- a/internal/loader/discovery.go +++ b/internal/loader/discovery.go @@ -27,11 +27,11 @@ import ( ) type goListRequest struct { - WD string - Env []string - Tags string - Patterns []string - NeedDeps bool + WD string + Env []string + Tags string + Patterns []string + NeedDeps bool SkipCompiled bool } diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go index 1a93bdf..1151853 100644 --- a/internal/loader/discovery_cache.go +++ b/internal/loader/discovery_cache.go @@ -127,13 +127,13 @@ func discoveryCachePath(req goListRequest) (string, error) { return "", err } sumReq := struct { - Version int - WD string - Tags string - Patterns []string - NeedDeps bool + Version int + WD string + Tags string + 
Patterns []string + NeedDeps bool SkipCompiled bool - Go string + Go string }{ Version: discoveryCacheVersion, WD: canonicalLoaderPath(req.WD), From b0177a66b4c7628cd9cf2e639df1c77da8838231 Mon Sep 17 00:00:00 2001 From: jungooji Date: Mon, 30 Mar 2026 05:16:51 +0900 Subject: [PATCH 80/82] perf: content-hash based cache keys for CI compatibility (#5) Replace mtime-based cache invalidation with content hashing: - artifact_cache: use SHA-256 of file content instead of ModTime - discovery_cache: use content hash for file matching, add WIRE_DISCOVERY_CACHE_DIR env var - Bump cache versions (artifact v4, discovery v4) This enables wire cache to work correctly in CI environments where file mtimes are not preserved across runs (e.g., S3 cache restore, git checkout). Co-authored-by: Claude Opus 4.6 (1M context) --- internal/loader/artifact_cache.go | 57 ++++++++++++++++-------------- internal/loader/discovery_cache.go | 43 ++++++++++++++-------- 2 files changed, 60 insertions(+), 40 deletions(-) diff --git a/internal/loader/artifact_cache.go b/internal/loader/artifact_cache.go index e920d5a..ff3f753 100644 --- a/internal/loader/artifact_cache.go +++ b/internal/loader/artifact_cache.go @@ -20,10 +20,10 @@ import ( "encoding/hex" "go/token" "go/types" + "io" "os" "path/filepath" "runtime" - "strconv" "golang.org/x/tools/go/gcexportdata" ) @@ -62,7 +62,7 @@ func loaderArtifactPath(env []string, meta *packageMeta, isLocal bool) (string, func loaderArtifactKey(meta *packageMeta, isLocal bool) (string, error) { sum := sha256.New() - sum.Write([]byte("wire-loader-artifact-v3\n")) + sum.Write([]byte("wire-loader-artifact-v4\n")) sum.Write([]byte(runtime.Version())) sum.Write([]byte{'\n'}) sum.Write([]byte(meta.ImportPath)) @@ -73,26 +73,15 @@ func loaderArtifactKey(meta *packageMeta, isLocal bool) (string, error) { sum.Write([]byte(meta.Export)) sum.Write([]byte{'\n'}) if meta.Export != "" { - info, err := os.Stat(meta.Export) + h, err := hashFileContent(meta.Export) if err != nil 
{ return "", err } - sum.Write([]byte(strconv.FormatInt(info.Size(), 10))) - sum.Write([]byte{'\n'}) - sum.Write([]byte(strconv.FormatInt(info.ModTime().UnixNano(), 10))) + sum.Write([]byte(h)) sum.Write([]byte{'\n'}) } else { - for _, name := range metaFiles(meta) { - info, err := os.Stat(name) - if err != nil { - return "", err - } - sum.Write([]byte(name)) - sum.Write([]byte{'\n'}) - sum.Write([]byte(strconv.FormatInt(info.Size(), 10))) - sum.Write([]byte{'\n'}) - sum.Write([]byte(strconv.FormatInt(info.ModTime().UnixNano(), 10))) - sum.Write([]byte{'\n'}) + if err := hashMetaFiles(sum, metaFiles(meta)); err != nil { + return "", err } } if meta.Error != nil { @@ -101,19 +90,35 @@ func loaderArtifactKey(meta *packageMeta, isLocal bool) (string, error) { } return hex.EncodeToString(sum.Sum(nil)), nil } - for _, name := range metaFiles(meta) { - info, err := os.Stat(name) - if err != nil { - return "", err - } + if err := hashMetaFiles(sum, metaFiles(meta)); err != nil { + return "", err + } + return hex.EncodeToString(sum.Sum(nil)), nil +} + +// hashFileContent returns the hex-encoded SHA-256 of the file content. +func hashFileContent(path string) (string, error) { + data, err := os.ReadFile(path) + if err != nil { + return "", err + } + h := sha256.Sum256(data) + return hex.EncodeToString(h[:]), nil +} + +// hashMetaFiles writes content-based hashes for each file into sum. 
+func hashMetaFiles(sum io.Writer, names []string) error { + for _, name := range names { sum.Write([]byte(name)) sum.Write([]byte{'\n'}) - sum.Write([]byte(strconv.FormatInt(info.Size(), 10))) - sum.Write([]byte{'\n'}) - sum.Write([]byte(strconv.FormatInt(info.ModTime().UnixNano(), 10))) + h, err := hashFileContent(name) + if err != nil { + return err + } + sum.Write([]byte(h)) sum.Write([]byte{'\n'}) } - return hex.EncodeToString(sum.Sum(nil)), nil + return nil } func readLoaderArtifact(path string, fset *token.FileSet, imports map[string]*types.Package, pkgPath string) (*types.Package, error) { diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go index 1151853..6d59a22 100644 --- a/internal/loader/discovery_cache.go +++ b/internal/loader/discovery_cache.go @@ -28,10 +28,11 @@ type discoveryLocalPackage struct { } type discoveryFileMeta struct { - Path string - Size int64 - ModTime int64 - IsDir bool + Path string + Size int64 + ModTime int64 // deprecated: kept for gob compat, not used for matching + ContentHash string // sha256 of file content + IsDir bool } type discoveryDirMeta struct { @@ -44,7 +45,7 @@ type discoveryFileFingerprint struct { Hash string } -const discoveryCacheVersion = 3 +const discoveryCacheVersion = 4 func readDiscoveryCache(req goListRequest) (map[string]*packageMeta, bool) { entry, err := loadDiscoveryCacheEntry(req) @@ -121,10 +122,16 @@ func validateDiscoveryCacheEntry(entry *discoveryCacheEntry) bool { return true } +const discoveryCacheDirEnv = "WIRE_DISCOVERY_CACHE_DIR" + func discoveryCachePath(req goListRequest) (string, error) { - base, err := os.UserCacheDir() - if err != nil { - return "", err + dir := os.Getenv(discoveryCacheDirEnv) + if dir == "" { + base, err := os.UserCacheDir() + if err != nil { + return "", err + } + dir = filepath.Join(base, "wire", "discovery-cache") } sumReq := struct { Version int @@ -147,7 +154,7 @@ func discoveryCachePath(req goListRequest) (string, error) { if err != 
nil { return "", err } - return filepath.Join(base, "wire", "discovery-cache", key+".gob"), nil + return filepath.Join(dir, key+".gob"), nil } func loadDiscoveryCacheEntry(req goListRequest) (*discoveryCacheEntry, error) { @@ -188,11 +195,19 @@ func statDiscoveryFile(path string) (discoveryFileMeta, bool) { if err != nil { return discoveryFileMeta{}, false } + h := "" + if !info.IsDir() { + var err error + h, err = hashFileContent(path) + if err != nil { + return discoveryFileMeta{}, false + } + } return discoveryFileMeta{ - Path: canonicalLoaderPath(path), - Size: info.Size(), - ModTime: info.ModTime().UnixNano(), - IsDir: info.IsDir(), + Path: canonicalLoaderPath(path), + Size: info.Size(), + ContentHash: h, + IsDir: info.IsDir(), }, true } @@ -201,7 +216,7 @@ func matchesDiscoveryFile(fm discoveryFileMeta) bool { if !ok { return false } - return cur.Size == fm.Size && cur.ModTime == fm.ModTime && cur.IsDir == fm.IsDir + return cur.ContentHash == fm.ContentHash && cur.IsDir == fm.IsDir } func statDiscoveryDir(path string) (discoveryDirMeta, bool) { From 8a3ed6d1f0d54fb4d485e1357edfb5609eac2a27 Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Sun, 29 Mar 2026 15:20:47 -0500 Subject: [PATCH 81/82] refactor: share wire cache base directory --- cmd/wire/cache_cmd.go | 31 ++++++----------- cmd/wire/cache_cmd_test.go | 34 ++++++++++++++---- internal/cachepaths/cachepaths.go | 55 ++++++++++++++++++++++++++++++ internal/loader/artifact_cache.go | 13 +++---- internal/loader/discovery_cache.go | 14 ++++---- internal/wire/output_cache.go | 12 ++----- 6 files changed, 105 insertions(+), 54 deletions(-) create mode 100644 internal/cachepaths/cachepaths.go diff --git a/cmd/wire/cache_cmd.go b/cmd/wire/cache_cmd.go index 1bc4560..5be1729 100644 --- a/cmd/wire/cache_cmd.go +++ b/cmd/wire/cache_cmd.go @@ -25,16 +25,15 @@ import ( "strings" "github.com/google/subcommands" + + "github.com/goforj/wire/internal/cachepaths" ) const ( - loaderArtifactDirEnv = "WIRE_LOADER_ARTIFACT_DIR" 
- outputCacheDirEnv = "WIRE_OUTPUT_CACHE_DIR" - semanticCacheDirEnv = "WIRE_SEMANTIC_CACHE_DIR" + loaderArtifactDirEnv = cachepaths.LoaderArtifactDirEnv + outputCacheDirEnv = cachepaths.OutputCacheDirEnv ) -var osUserCacheDir = os.UserCacheDir - type cacheCmd struct { clear bool } @@ -113,11 +112,11 @@ func (cmd *cacheCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...inter } func wireCacheRoot(env []string) (string, error) { - base, err := osUserCacheDir() + root, err := cachepaths.Root(env) if err != nil { return "", fmt.Errorf("resolve user cache dir: %w", err) } - return filepath.Join(base, "wire"), nil + return root, nil } func clearWireCaches(env []string) ([]string, error) { @@ -148,11 +147,11 @@ func clearWireCaches(env []string) ([]string, error) { } func wireCacheTargets(env []string, userCacheDir string) []cacheTarget { - baseWire := filepath.Join(userCacheDir, "wire") + baseWire := cachepaths.EnvValueDefault(env, cachepaths.BaseDirEnv, filepath.Join(userCacheDir, "wire")) targets := []cacheTarget{ - {name: "loader-artifacts", path: envValueDefault(env, loaderArtifactDirEnv, filepath.Join(baseWire, "loader-artifacts"))}, - {name: "discovery-cache", path: filepath.Join(baseWire, "discovery-cache")}, - {name: "output-cache", path: envValueDefault(env, outputCacheDirEnv, filepath.Join(baseWire, "output-cache"))}, + {name: "loader-artifacts", path: cachepaths.EnvValueDefault(env, loaderArtifactDirEnv, filepath.Join(baseWire, "loader-artifacts"))}, + {name: "discovery-cache", path: cachepaths.EnvValueDefault(env, cachepaths.DiscoveryCacheDirEnv, filepath.Join(baseWire, "discovery-cache"))}, + {name: "output-cache", path: cachepaths.EnvValueDefault(env, outputCacheDirEnv, filepath.Join(baseWire, "output-cache"))}, } seen := make(map[string]bool, len(targets)) deduped := make([]cacheTarget, 0, len(targets)) @@ -168,13 +167,3 @@ func wireCacheTargets(env []string, userCacheDir string) []cacheTarget { sort.Slice(deduped, func(i, j int) bool { return 
deduped[i].name < deduped[j].name }) return deduped } - -func envValueDefault(env []string, key, fallback string) string { - for i := len(env) - 1; i >= 0; i-- { - parts := strings.SplitN(env[i], "=", 2) - if len(parts) == 2 && parts[0] == key && parts[1] != "" { - return parts[1] - } - } - return fallback -} diff --git a/cmd/wire/cache_cmd_test.go b/cmd/wire/cache_cmd_test.go index 578c2aa..1d13acb 100644 --- a/cmd/wire/cache_cmd_test.go +++ b/cmd/wire/cache_cmd_test.go @@ -4,6 +4,8 @@ import ( "os" "path/filepath" "testing" + + "github.com/goforj/wire/internal/cachepaths" ) func TestWireCacheTargetsDefault(t *testing.T) { @@ -26,9 +28,9 @@ func TestWireCacheTargetsDefault(t *testing.T) { func TestWireCacheRoot(t *testing.T) { base := filepath.Join(t.TempDir(), "cache") - old := osUserCacheDir - osUserCacheDir = func() (string, error) { return base, nil } - defer func() { osUserCacheDir = old }() + old := cachepaths.UserCacheDir + cachepaths.UserCacheDir = func() (string, error) { return base, nil } + defer func() { cachepaths.UserCacheDir = old }() got, err := wireCacheRoot(nil) if err != nil { @@ -44,11 +46,12 @@ func TestWireCacheTargetsRespectOverrides(t *testing.T) { base := filepath.Join(t.TempDir(), "cache") env := []string{ loaderArtifactDirEnv + "=" + filepath.Join(base, "loader"), + cachepaths.DiscoveryCacheDirEnv + "=" + filepath.Join(base, "discovery"), outputCacheDirEnv + "=" + filepath.Join(base, "output"), } got := wireCacheTargets(env, base) want := map[string]string{ - "discovery-cache": filepath.Join(base, "wire", "discovery-cache"), + "discovery-cache": filepath.Join(base, "discovery"), "loader-artifacts": filepath.Join(base, "loader"), "output-cache": filepath.Join(base, "output"), } @@ -59,6 +62,23 @@ func TestWireCacheTargetsRespectOverrides(t *testing.T) { } } +func TestWireCacheTargetsRespectBaseDirOverride(t *testing.T) { + base := filepath.Join(t.TempDir(), "cache") + root := filepath.Join(base, "wire-root") + env := 
[]string{cachepaths.BaseDirEnv + "=" + root} + got := wireCacheTargets(env, base) + want := map[string]string{ + "discovery-cache": filepath.Join(root, "discovery-cache"), + "loader-artifacts": filepath.Join(root, "loader-artifacts"), + "output-cache": filepath.Join(root, "output-cache"), + } + for _, target := range got { + if target.path != want[target.name] { + t.Fatalf("%s path = %q, want %q", target.name, target.path, want[target.name]) + } + } +} + func TestClearWireCachesRemovesTargets(t *testing.T) { base := filepath.Join(t.TempDir(), "cache") env := []string{ @@ -73,9 +93,9 @@ func TestClearWireCachesRemovesTargets(t *testing.T) { t.Fatalf("WriteFile(%q): %v", target.path, err) } } - old := osUserCacheDir - osUserCacheDir = func() (string, error) { return base, nil } - defer func() { osUserCacheDir = old }() + old := cachepaths.UserCacheDir + cachepaths.UserCacheDir = func() (string, error) { return base, nil } + defer func() { cachepaths.UserCacheDir = old }() cleared, err := clearWireCaches(env) if err != nil { diff --git a/internal/cachepaths/cachepaths.go b/internal/cachepaths/cachepaths.go new file mode 100644 index 0000000..f3e3adb --- /dev/null +++ b/internal/cachepaths/cachepaths.go @@ -0,0 +1,55 @@ +package cachepaths + +import ( + "os" + "path/filepath" + "strings" +) + +const ( + BaseDirEnv = "WIRE_CACHE_DIR" + LoaderArtifactDirEnv = "WIRE_LOADER_ARTIFACT_DIR" + DiscoveryCacheDirEnv = "WIRE_DISCOVERY_CACHE_DIR" + OutputCacheDirEnv = "WIRE_OUTPUT_CACHE_DIR" +) + +var UserCacheDir = os.UserCacheDir + +func Root(env []string) (string, error) { + if dir := envValue(env, BaseDirEnv); dir != "" { + return filepath.Clean(dir), nil + } + base, err := UserCacheDir() + if err != nil { + return "", err + } + return filepath.Join(base, "wire"), nil +} + +func Dir(env []string, specificEnv, name string) (string, error) { + if dir := envValue(env, specificEnv); dir != "" { + return filepath.Clean(dir), nil + } + root, err := Root(env) + if err != nil { + 
return "", err + } + return filepath.Join(root, name), nil +} + +func EnvValueDefault(env []string, key, fallback string) string { + if value := envValue(env, key); value != "" { + return value + } + return fallback +} + +func envValue(env []string, key string) string { + for i := len(env) - 1; i >= 0; i-- { + name, value, ok := strings.Cut(env[i], "=") + if ok && name == key && value != "" { + return value + } + } + return "" +} diff --git a/internal/loader/artifact_cache.go b/internal/loader/artifact_cache.go index ff3f753..a6dfdb1 100644 --- a/internal/loader/artifact_cache.go +++ b/internal/loader/artifact_cache.go @@ -26,11 +26,13 @@ import ( "runtime" "golang.org/x/tools/go/gcexportdata" + + "github.com/goforj/wire/internal/cachepaths" ) const ( loaderArtifactEnv = "WIRE_LOADER_ARTIFACTS" - loaderArtifactDirEnv = "WIRE_LOADER_ARTIFACT_DIR" + loaderArtifactDirEnv = cachepaths.LoaderArtifactDirEnv ) func loaderArtifactEnabled(env []string) bool { @@ -38,14 +40,7 @@ func loaderArtifactEnabled(env []string) bool { } func loaderArtifactDir(env []string) (string, error) { - if dir := envValue(env, loaderArtifactDirEnv); dir != "" { - return dir, nil - } - base, err := os.UserCacheDir() - if err != nil { - return "", err - } - return filepath.Join(base, "wire", "loader-artifacts"), nil + return cachepaths.Dir(env, loaderArtifactDirEnv, "loader-artifacts") } func loaderArtifactPath(env []string, meta *packageMeta, isLocal bool) (string, error) { diff --git a/internal/loader/discovery_cache.go b/internal/loader/discovery_cache.go index 6d59a22..3381041 100644 --- a/internal/loader/discovery_cache.go +++ b/internal/loader/discovery_cache.go @@ -11,6 +11,8 @@ import ( "path/filepath" "runtime" "sort" + + "github.com/goforj/wire/internal/cachepaths" ) type discoveryCacheEntry struct { @@ -122,16 +124,12 @@ func validateDiscoveryCacheEntry(entry *discoveryCacheEntry) bool { return true } -const discoveryCacheDirEnv = "WIRE_DISCOVERY_CACHE_DIR" +const discoveryCacheDirEnv 
= cachepaths.DiscoveryCacheDirEnv func discoveryCachePath(req goListRequest) (string, error) { - dir := os.Getenv(discoveryCacheDirEnv) - if dir == "" { - base, err := os.UserCacheDir() - if err != nil { - return "", err - } - dir = filepath.Join(base, "wire", "discovery-cache") + dir, err := cachepaths.Dir(req.Env, discoveryCacheDirEnv, "discovery-cache") + if err != nil { + return "", err } sumReq := struct { Version int diff --git a/internal/wire/output_cache.go b/internal/wire/output_cache.go index bd2bc8b..42fcaa4 100644 --- a/internal/wire/output_cache.go +++ b/internal/wire/output_cache.go @@ -14,11 +14,12 @@ import ( "golang.org/x/tools/go/packages" + "github.com/goforj/wire/internal/cachepaths" "github.com/goforj/wire/internal/loader" ) const ( - outputCacheDirEnv = "WIRE_OUTPUT_CACHE_DIR" + outputCacheDirEnv = cachepaths.OutputCacheDirEnv outputCacheEnabledEnv = "WIRE_OUTPUT_CACHE" ) @@ -122,14 +123,7 @@ func outputCachePath(env []string, key string) (string, error) { } func outputCacheDir(env []string) (string, error) { - if dir := envValue(env, outputCacheDirEnv); dir != "" { - return dir, nil - } - base, err := os.UserCacheDir() - if err != nil { - return "", err - } - return filepath.Join(base, "wire", "output-cache"), nil + return cachepaths.Dir(env, outputCacheDirEnv, "output-cache") } func readOutputCache(path string) (*outputCacheEntry, bool) { From 275e27107d9d92a5dc399a78358deba7f788ee0d Mon Sep 17 00:00:00 2001 From: Chris Miles Date: Wed, 1 Apr 2026 00:03:18 -0500 Subject: [PATCH 82/82] docs: add readme in internal for future travelers --- internal/README.md | 213 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 internal/README.md diff --git a/internal/README.md b/internal/README.md new file mode 100644 index 0000000..1134d90 --- /dev/null +++ b/internal/README.md @@ -0,0 +1,213 @@ +# Internal Package Guide + +This directory holds the implementation behind the `wire` CLI and library. 
+If you are new to the codebase, the two packages that matter most are: + +- `internal/wire`: parses injector/provider-set source and generates `wire_gen.go` +- `internal/loader`: loads package graphs, supports the custom loader, and manages loader-side caches + +Everything else is either small support code or repo maintenance helpers. + +## High-Level Flow + +For `wire gen`, the rough flow is: + +1. `cmd/wire` parses flags and calls `wire.Generate(...)` +2. `internal/wire` tries the output cache fast path +3. `internal/loader` loads the root graph and typed package graph +4. `internal/wire/parse.go` walks source/type information and builds Wire's internal model +5. `internal/wire/wire.go` generates formatted Go output +6. `cmd/wire` writes `wire_gen.go` + +If you are debugging behavior, start in: + +- [`wire/wire.go`](./wire/wire.go) +- [`wire/parse.go`](./wire/parse.go) +- [`loader/loader.go`](./loader/loader.go) +- [`loader/custom.go`](./loader/custom.go) + +## Package Overview + +### `internal/wire` + +This is the core implementation of Wire's analysis and code generation. + +Important responsibilities: + +- loading packages through the loader abstraction +- parsing `wire.NewSet`, `wire.Bind`, `wire.Struct`, `wire.FieldsOf`, and injector functions +- building the internal provider graph model +- validating dependency graphs +- rendering generated Go source +- managing the output cache + +Important files: + +- [`wire/wire.go`](./wire/wire.go) + Main generation entry point. Defines `Generate`, `GenerateResult`, and the final generation loop. +- [`wire/parse.go`](./wire/parse.go) + The biggest conceptual file in the repo. Defines the internal model (`ProviderSet`, `Provider`, `Value`, `Field`, etc.) and parses source/type information into that model. +- [`wire/analyze.go`](./wire/analyze.go) + Performs graph validation and dependency analysis once provider sets are parsed. 
+- [`wire/output_cache.go`](./wire/output_cache.go) + Output cache read/write and key construction. +- [`wire/load_debug.go`](./wire/load_debug.go) + Debug/timing summaries for loaded package graphs. +- [`wire/timing.go`](./wire/timing.go) + Timing/debug plumbing for the `wire` package. +- [`wire/loader_timing_bridge.go`](./wire/loader_timing_bridge.go) + Connects loader timing output into wire timing output. +- [`wire/errors.go`](./wire/errors.go) + Error collection and formatting helpers. +- [`wire/copyast.go`](./wire/copyast.go) + AST copying helpers used during generation. +- [`wire/loader_validation.go`](./wire/loader_validation.go) + Loader-mode-related validation helpers. + +Useful mental model: + +- `parse.go` turns package syntax/types into a Wire-specific graph model +- `analyze.go` validates that graph +- `wire.go` emits Go code from that graph + +### `internal/loader` + +This package abstracts package loading and is the main performance-sensitive layer. + +Important responsibilities: + +- choosing between the custom loader and `go/packages` fallback +- root graph loading vs typed package loading +- discovery via `go list` +- discovery cache +- loader artifact cache +- touched-package validation +- loader timings and fallback reasons + +Important files: + +- [`loader/loader.go`](./loader/loader.go) + Public loader API and shared request/result structs. This is the best entry point for understanding loader responsibilities. +- [`loader/custom.go`](./loader/custom.go) + The custom loader implementation. This is the most performance-critical file in the repo. +- [`loader/fallback.go`](./loader/fallback.go) + `go/packages` fallback implementation and fallback reason handling. +- [`loader/discovery.go`](./loader/discovery.go) + Runs `go list`, decodes package metadata, and populates the discovery cache. +- [`loader/discovery_cache.go`](./loader/discovery_cache.go) + Discovery cache storage and invalidation. 
+- [`loader/artifact_cache.go`](./loader/artifact_cache.go) + Loader artifact cache keying and read/write helpers. +- [`loader/mode.go`](./loader/mode.go) + Loader mode selection helpers. +- [`loader/timing.go`](./loader/timing.go) + Timing/debug plumbing for the loader package. + +Useful mental model: + +- `discovery.go` answers "what packages/files/imports exist?" +- `custom.go` answers "how do we build the package graph and type info efficiently?" +- `artifact_cache.go` is the typed-package reuse layer +- `fallback.go` is the correctness backstop + +### `internal/cachepaths` + +Small shared helper package for Wire-managed cache directories. + +Responsibilities: + +- resolve the shared cache root +- resolve specific cache directories +- support shared and specific env var overrides + +Important file: + +- [`cachepaths/cachepaths.go`](./cachepaths/cachepaths.go) + +This package exists to keep cache path policy centralized instead of duplicated across loader, output cache, and the `wire cache` command. + +## Internal Data Structures + +These are the main structs worth learning first. + +### In `internal/wire` + +- `ProviderSet` + The central Wire model. Represents a set built from `wire.NewSet(...)` or `wire.Build(...)`. +- `Provider` + A constructor source, either a function or a named struct provider. +- `IfaceBinding` + A `wire.Bind(...)` relationship. +- `Value` + A `wire.Value(...)` source. +- `Field` + A `wire.FieldsOf(...)` source. +- `InjectorArgs` / `InjectorArg` + Injector function argument modeling. + +If you understand `ProviderSet`, most of `parse.go` becomes easier to follow. + +### In `internal/loader` + +- `RootLoadRequest` / `RootLoadResult` + Used for lightweight root graph loading, primarily for cache lookup and root discovery. +- `PackageLoadRequest` / `PackageLoadResult` + Used for full typed package loading. +- `LazyLoadRequest` / `LazyLoadResult` + Used for package-targeted typed loading. 
+- `TouchedValidationRequest` / `TouchedValidationResult` + Used to validate whether a touched local package can stay on the custom path. +- `DiscoverySnapshot` + Captures `go list` metadata so later phases can reuse discovery without rerunning it. +- `packageMeta` + Internal metadata from `go list`. This is the custom loader's raw package description. +- `customTypedGraphLoader` + Main stateful custom loader for building typed package graphs. + +## Empty or Transitional Directories + +- `internal/cachestore` +- `internal/semanticcache` + +These directories currently exist but do not contain active implementation files. +Treat them as inactive unless they are populated later. + +## Repo Maintenance Files + +There are also a few internal scripts/data files that support repo maintenance: + +- [`alldeps`](./alldeps) + Dependency allowlist/check input. +- [`runtests.sh`](./runtests.sh) + Repo test runner used in CI/dev workflows. +- [`listdeps.sh`](./listdeps.sh) + Dependency listing helper. +- [`check_api_change.sh`](./check_api_change.sh) + API change check helper. + +## Suggested Reading Order + +If you are trying to understand the codebase quickly: + +1. [`wire/wire.go`](./wire/wire.go) +2. [`loader/loader.go`](./loader/loader.go) +3. [`loader/custom.go`](./loader/custom.go) +4. [`wire/parse.go`](./wire/parse.go) +5. [`wire/analyze.go`](./wire/analyze.go) +6. [`wire/output_cache.go`](./wire/output_cache.go) +7. [`loader/discovery.go`](./loader/discovery.go) +8. [`loader/artifact_cache.go`](./loader/artifact_cache.go) + +## Practical Notes For New Readers + +- If a behavior issue involves parsing provider sets, start in `internal/wire/parse.go`. +- If a behavior issue involves performance, cache hits, or package loading, start in `internal/loader/custom.go`. +- If a behavior issue involves "why did we skip generation?" or "why did generation return instantly?", check `internal/wire/output_cache.go`. 
+- If a behavior issue only appears in one environment or workspace, inspect cache path resolution in `internal/cachepaths/cachepaths.go` and discovery behavior in `internal/loader/discovery.go`. + +This repo has two real centers of complexity: + +- the Wire semantic model in `internal/wire` +- the package loading/caching machinery in `internal/loader` + +Most other internal code exists to support those two areas.