diff --git a/internal/wire/analyze.go b/internal/wire/analyze.go index 9650ef1..ccea2c8 100644 --- a/internal/wire/analyze.go +++ b/internal/wire/analyze.go @@ -52,6 +52,10 @@ type call struct { pkg *types.Package name string + // methodExprRecv is the receiver type for a method expression provider. + // It is nil for package-level function providers. + methodExprRecv types.Type + // args is a list of arguments to call the provider with. Each element is: // a) one of the givens (args[i] < len(given)), // b) the result of a previous provider call (args[i] >= len(given)) @@ -196,16 +200,17 @@ dfs: } } calls = append(calls, call{ - kind: kind, - pkg: p.Pkg, - name: p.Name, - args: args, - varargs: p.Varargs, - fieldNames: fieldNames, - ins: ins, - out: curr.t, - hasCleanup: p.HasCleanup, - hasErr: p.HasErr, + kind: kind, + pkg: p.Pkg, + name: p.Name, + methodExprRecv: p.MethodExprRecv, + args: args, + varargs: p.Varargs, + fieldNames: fieldNames, + ins: ins, + out: curr.t, + hasCleanup: p.HasCleanup, + hasErr: p.HasErr, }) case pv.IsValue(): v := pv.Value() diff --git a/internal/wire/parse.go b/internal/wire/parse.go index 2e9c428..e8efcbf 100644 --- a/internal/wire/parse.go +++ b/internal/wire/parse.go @@ -156,6 +156,10 @@ type Provider struct { // Name is the name of the Go object. Name string + // MethodExprRecv is the receiver type for a method expression provider. + // It is nil for package-level function providers. + MethodExprRecv types.Type + // Pos is the source position of the func keyword or type spec // defining this provider. Pos token.Pos @@ -718,6 +722,12 @@ func valueSpecForVar(fset *token.FileSet, files []*ast.File, obj *types.Var) *as func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Expr, varName string) (interface{}, []error) { exprPos := oc.fset.Position(expr.Pos()) expr = astutil.Unparen(expr) + if sel, ok := expr.(*ast.SelectorExpr); ok { + if selInfo := info.Selections[sel]; selInfo != nil && selInfo.Kind() == types.MethodExpr { + p, errs := processMethodExprProvider(oc.fset, info, sel, selInfo) + return p, notePositionAll(exprPos, errs) + } + } if obj := qualifiedIdentObject(info, expr); obj != nil { item, errs := oc.get(obj) return item, mapErrors(errs, func(err error) error { @@ -877,20 +887,37 @@ func qualifiedIdentObject(info *types.Info, expr ast.Expr) types.Object { func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []error) { sig := fn.Type().(*types.Signature) fpos := fn.Pos() + return processProviderSignature(fset, fn.Pkg(), fn.Name(), nil, fpos, sig) +} + +func processMethodExprProvider(fset *token.FileSet, info *types.Info, expr *ast.SelectorExpr, sel *types.Selection) (*Provider, []error) { + obj, ok := sel.Obj().(*types.Func) + if !ok { + return nil, []error{fmt.Errorf("%s is not a function", expr.Sel.Name)} + } + sig, ok := info.TypeOf(expr).(*types.Signature) + if !ok { + return nil, []error{fmt.Errorf("method expression %s does not have a function signature", expr.Sel.Name)} + } + return processProviderSignature(fset, obj.Pkg(), obj.Name(), sel.Recv(), expr.Pos(), sig) +} + +func processProviderSignature(fset *token.FileSet, pkg *types.Package, name string, recv types.Type, pos token.Pos, sig *types.Signature) (*Provider, []error) { providerSig, err := funcOutput(sig) if err != nil { - return nil, []error{notePosition(fset.Position(fpos), fmt.Errorf("wrong signature for provider %s: %v", fn.Name(), err))} + return nil, []error{notePosition(fset.Position(pos), fmt.Errorf("wrong signature for provider %s: %v", name, err))} } params := sig.Params() provider := &Provider{ - Pkg: fn.Pkg(), - Name: fn.Name(), - Pos: fn.Pos(), - Args: make([]ProviderInput, params.Len()), - Varargs: sig.Variadic(), - Out: []types.Type{providerSig.out}, - HasCleanup: providerSig.cleanup, - HasErr: providerSig.err, + Pkg: pkg, + Name: name, + MethodExprRecv: recv, + Pos: pos, + Args: make([]ProviderInput, params.Len()), + Varargs: sig.Variadic(), + Out: []types.Type{providerSig.out}, + HasCleanup: providerSig.cleanup, + HasErr: providerSig.err, } for i := 0; i < params.Len(); i++ { provider.Args[i] = ProviderInput{ @@ -898,7 +925,7 @@ func processFuncProvider(fset *token.FileSet, fn *types.Func) (*Provider, []erro } for j := 0; j < i; j++ { if types.Identical(provider.Args[i].Type, provider.Args[j].Type) { - return nil, []error{notePosition(fset.Position(fpos), fmt.Errorf("provider has multiple parameters of type %s", types.TypeString(provider.Args[j].Type, nil)))} + return nil, []error{notePosition(fset.Position(pos), fmt.Errorf("provider has multiple parameters of type %s", types.TypeString(provider.Args[j].Type, nil)))} } } } diff --git a/internal/wire/testdata/MethodExprProvider/foo/foo.go b/internal/wire/testdata/MethodExprProvider/foo/foo.go new file mode 100644 index 0000000..4e53baf --- /dev/null +++ b/internal/wire/testdata/MethodExprProvider/foo/foo.go @@ -0,0 +1,41 @@ +// Copyright 2018 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import "fmt" + +func main() { + db, err := initDB() + if err != nil { + panic(err) + } + fmt.Println(db.DSN) +} + +type Options struct { + DSN string +} + +type DB struct { + DSN string +} + +func provideOptions() *Options { + return &Options{DSN: "postgres://wire"} +} + +func (o *Options) ToDB() (*DB, error) { + return &DB{DSN: o.DSN}, nil +} diff --git a/internal/wire/testdata/MethodExprProvider/foo/wire.go b/internal/wire/testdata/MethodExprProvider/foo/wire.go new file mode 100644 index 0000000..cfe9f31 --- /dev/null +++ b/internal/wire/testdata/MethodExprProvider/foo/wire.go @@ -0,0 +1,28 @@ +// Copyright 2018 The Wire Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build wireinject +// +build wireinject + +package main + +import "github.com/goforj/wire" + +func initDB() (*DB, error) { + wire.Build( + provideOptions, + (*Options).ToDB, + ) + return nil, nil +} diff --git a/internal/wire/testdata/MethodExprProvider/pkg b/internal/wire/testdata/MethodExprProvider/pkg new file mode 100644 index 0000000..f7a5c8c --- /dev/null +++ b/internal/wire/testdata/MethodExprProvider/pkg @@ -0,0 +1 @@ +example.com/foo diff --git a/internal/wire/testdata/MethodExprProvider/want/program_out.txt b/internal/wire/testdata/MethodExprProvider/want/program_out.txt new file mode 100644 index 0000000..601359b --- /dev/null +++ b/internal/wire/testdata/MethodExprProvider/want/program_out.txt @@ -0,0 +1 @@ +postgres://wire diff --git a/internal/wire/testdata/MethodExprProvider/want/wire_gen.go b/internal/wire/testdata/MethodExprProvider/want/wire_gen.go new file mode 100644 index 0000000..c18a2db --- /dev/null +++ b/internal/wire/testdata/MethodExprProvider/want/wire_gen.go @@ -0,0 +1,18 @@ +// Code generated by Wire. DO NOT EDIT. + +//go:generate go run -mod=mod github.com/goforj/wire/cmd/wire +//go:build !wireinject +// +build !wireinject + +package main + +// Injectors from wire.go: + +func initDB() (*DB, error) { + options := provideOptions() + db, err := (*Options).ToDB(options) + if err != nil { + return nil, err + } + return db, nil +} diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 9f5bb9e..6f848fa 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -720,7 +720,7 @@ func (ig *injectorGen) funcProviderCall(lname string, c *call, injectSig outputS ig.p(", %s", ig.errVar) } ig.p(" := ") - ig.p("%s(", ig.g.qualifiedID(c.pkg.Name(), c.pkg.Path(), c.name)) + ig.p("%s(", ig.funcProviderExpr(c)) for i, a := range c.args { if i > 0 { ig.p(", ") @@ -750,6 +750,17 @@ func (ig *injectorGen) funcProviderCall(lname string, c *call, injectSig outputS } } +func (ig *injectorGen) funcProviderExpr(c *call) string { + if c.methodExprRecv == nil { + return ig.g.qualifiedID(c.pkg.Name(), c.pkg.Path(), c.name) + } + recv := types.TypeString(c.methodExprRecv, ig.g.qualifyPkg) + if _, ok := c.methodExprRecv.(*types.Pointer); ok { + recv = "(" + recv + ")" + } + return recv + "." + c.name +} + func (ig *injectorGen) structProviderCall(lname string, c *call) { ig.p("\t%s", lname) ig.p(" := ")