diff --git a/fftest/vars.go b/fftest/vars.go index cd47d86..d54d922 100644 --- a/fftest/vars.go +++ b/fftest/vars.go @@ -2,6 +2,7 @@ package fftest import ( "errors" + "flag" "reflect" "strings" "testing" @@ -93,3 +94,18 @@ func Compare(t *testing.T, want, have *Vars) { } } } + +type Custom struct { + v string +} + +var _ flag.Value = (*Custom)(nil) + +func (c *Custom) Set(s string) error { + c.v = s + return nil +} + +func (c *Custom) String() string { + return c.v +} diff --git a/flag_set.go b/flag_set.go index 4107631..06d7b9e 100644 --- a/flag_set.go +++ b/flag_set.go @@ -713,18 +713,26 @@ func (fs *FlagSet) AddStruct(val any) error { // Produce a flag.Value representing the field. { - var ( - fieldValAddr = fieldVal.Addr() - fieldValAddrTyp = fieldValAddr.Type() - fieldValAddrIface = fieldValAddr.Interface() - flagValueElemTyp = reflect.TypeOf((*flag.Value)(nil)).Elem() - ) - if fieldValAddrTyp.Implements(flagValueElemTyp) { - // The field implements flag.Value, we can use it directly. - cfg.Value = fieldValAddrIface.(flag.Value) + flagValueElemTyp := reflect.TypeOf((*flag.Value)(nil)).Elem() + if fieldVal.IsValid() && fieldVal.CanAddr() && fieldVal.Addr().Type().Implements(flagValueElemTyp) { + cfg.Value = fieldVal.Addr().Interface().(flag.Value) + if def != "" { + if err := cfg.Value.Set(def); err != nil { + return fmt.Errorf("%s: set default: %w", fieldName, err) + } + } + } else if fieldVal.IsValid() && fieldVal.Type().Implements(flagValueElemTyp) && fieldVal.CanInterface() { + if fieldVal.IsNil() { + return fmt.Errorf("%s: nil value for otherwise valid type (%s)", fieldName, fieldTyp.Type.String()) + } + cfg.Value = fieldVal.Interface().(flag.Value) + if def != "" { + if err := cfg.Value.Set(def); err != nil { + return fmt.Errorf("%s: set default: %w", fieldName, err) + } + } } else { - // Try to construct a new flag value. - v, err := ffval.NewValueReflect(fieldValAddrIface, def) + v, err := ffval.NewValueReflect(fieldVal.Addr().Interface(), def) if err != nil { return fmt.Errorf("%s: %w", fieldName, err) } diff --git a/flag_set_test.go b/flag_set_test.go index 928c6d7..583e15d 100644 --- a/flag_set_test.go +++ b/flag_set_test.go @@ -571,6 +571,59 @@ func TestFlagSet_StructEmbedded(t *testing.T) { } } +func TestFlagSet_StructCustom(t *testing.T) { + t.Parallel() + + type A struct { + Foo fftest.Custom `ff:"long=foo, default=abc"` + Bar *fftest.Custom `ff:"long=bar, default=def"` + } + + t.Run("valid no args", func(t *testing.T) { + fs := ff.NewFlagSet(t.Name()) + aflags := A{Bar: &fftest.Custom{}} + if err := fs.AddStruct(&aflags); err != nil { + t.Fatalf("AddStruct(&aflags): %v", err) + } + if err := ff.Parse(fs, []string{}); err != nil { + t.Fatal(err) + } + if want, have := "abc", aflags.Foo.String(); want != have { + t.Errorf("Foo: want %q, have %q", want, have) + } + if want, have := "def", aflags.Bar.String(); want != have { + t.Errorf("Bar: want %q, have %q", want, have) + } + }) + + t.Run("valid with args", func(t *testing.T) { + fs := ff.NewFlagSet(t.Name()) + aflags := A{Bar: &fftest.Custom{}} + if err := fs.AddStruct(&aflags); err != nil { + t.Fatalf("AddStruct(&aflags): %v", err) + } + if err := ff.Parse(fs, []string{"--foo=123", "--bar", "456"}); err != nil { + t.Fatal(err) + } + if want, have := "123", aflags.Foo.String(); want != have { + t.Errorf("Foo: want %q, have %q", want, have) + } + if want, have := "456", aflags.Bar.String(); want != have { + t.Errorf("Bar: want %q, have %q", want, have) + } + }) + + t.Run("invalid", func(t *testing.T) { + fs := ff.NewFlagSet(t.Name()) + aflags := A{} // nil Bar + if err := fs.AddStruct(&aflags); err == nil { + t.Fatalf("AddStruct(&aflags): wanted err, got none") + } else { + t.Logf("AddStruct(&aflags): got expected error (%v)", err) + } + }) +} + func TestFlagSet_Std(t *testing.T) { t.Parallel()