diff --git a/bool.go b/bool.go index c4c5c0bf..5f02373e 100644 --- a/bool.go +++ b/bool.go @@ -1,6 +1,9 @@ package pflag -import "strconv" +import ( + "errors" + "strconv" +) // optional interface to indicate boolean flags that can be // supplied without "=value" text @@ -19,8 +22,11 @@ func newBoolValue(val bool, p *bool) *boolValue { func (b *boolValue) Set(s string) error { v, err := strconv.ParseBool(s) + if err != nil { + return errors.New("must be true or false") + } *b = boolValue(v) - return err + return nil } func (b *boolValue) Type() string { diff --git a/duration.go b/duration.go index e9debef8..f5d78b74 100644 --- a/duration.go +++ b/duration.go @@ -1,6 +1,7 @@ package pflag import ( + "errors" "time" ) @@ -14,8 +15,11 @@ func newDurationValue(val time.Duration, p *time.Duration) *durationValue { func (d *durationValue) Set(s string) error { v, err := time.ParseDuration(s) + if err != nil { + return errors.New(`must be a duration like "30s" or "5m"`) + } *d = durationValue(v) - return err + return nil } func (d *durationValue) Type() string { diff --git a/flag_test.go b/flag_test.go index b367f7d2..282522a3 100644 --- a/flag_test.go +++ b/flag_test.go @@ -395,7 +395,7 @@ func testParse(f *FlagSet, t *testing.T) { } // Test invalid err = f.Parse([]string{"--bool=abcdefg"}) - expectedErr = `invalid argument "abcdefg" for "--bool" flag: strconv.ParseBool: parsing "abcdefg": invalid syntax` + expectedErr = `invalid argument "abcdefg" for "--bool" flag: must be true or false` if err == nil { t.Error("parse did not fail for invalid argument") } @@ -661,6 +661,43 @@ func TestParse(t *testing.T) { testParse(GetCommandLine(), t) } +// TestInvalidArgumentMessages locks in the human-readable suffixes returned +// when typed flags fail to parse their input, replacing the prior raw stdlib +// errors (e.g. `strconv.ParseBool: parsing "x": invalid syntax`). +func TestInvalidArgumentMessages(t *testing.T) { + cases := []struct { + flag string + register func(*FlagSet) + raw string + want string + }{ + {"bool", func(f *FlagSet) { f.Bool("bool", false, "") }, "x", + `invalid argument "x" for "--bool" flag: must be true or false`}, + {"int", func(f *FlagSet) { f.Int("int", 0, "") }, "x", + `invalid argument "x" for "--int" flag: must be an integer`}, + {"uint", func(f *FlagSet) { f.Uint("uint", 0, "") }, "-1", + `invalid argument "-1" for "--uint" flag: must be a non-negative integer`}, + {"float64", func(f *FlagSet) { f.Float64("float64", 0, "") }, "x", + `invalid argument "x" for "--float64" flag: must be a number`}, + {"duration", func(f *FlagSet) { f.Duration("duration", 0, "") }, "soon", + `invalid argument "soon" for "--duration" flag: must be a duration like "30s" or "5m"`}, + } + for _, tc := range cases { + t.Run(tc.flag, func(t *testing.T) { + f := NewFlagSet("test", ContinueOnError) + f.SetOutput(ioutil.Discard) + tc.register(f) + err := f.Parse([]string{"--" + tc.flag + "=" + tc.raw}) + if err == nil { + t.Fatalf("expected error parsing --%s=%q, got nil", tc.flag, tc.raw) + } + if err.Error() != tc.want { + t.Errorf("expected %q, got %q", tc.want, err.Error()) + } + }) + } +} + func TestParseAll(t *testing.T) { ResetForTesting(func() { t.Error("bad parse") }) testParseAll(GetCommandLine(), t) diff --git a/float32.go b/float32.go index a243f81f..95ba4da2 100644 --- a/float32.go +++ b/float32.go @@ -1,6 +1,9 @@ package pflag -import "strconv" +import ( + "errors" + "strconv" +) // -- float32 Value type float32Value float32 @@ -12,8 +15,11 @@ func newFloat32Value(val float32, p *float32) *float32Value { func (f *float32Value) Set(s string) error { v, err := strconv.ParseFloat(s, 32) + if err != nil { + return errors.New("must be a number") + } *f = float32Value(v) - return err + return nil } func (f *float32Value) Type() string { diff --git a/float64.go b/float64.go index 04b5492a..67f0a1e0 100644 --- a/float64.go +++ b/float64.go @@ -1,6 +1,9 @@ package pflag -import "strconv" +import ( + "errors" + "strconv" +) // -- float64 Value type float64Value float64 @@ -12,8 +15,11 @@ func newFloat64Value(val float64, p *float64) *float64Value { func (f *float64Value) Set(s string) error { v, err := strconv.ParseFloat(s, 64) + if err != nil { + return errors.New("must be a number") + } *f = float64Value(v) - return err + return nil } func (f *float64Value) Type() string { diff --git a/int.go b/int.go index 1474b89d..d48704a4 100644 --- a/int.go +++ b/int.go @@ -1,6 +1,9 @@ package pflag -import "strconv" +import ( + "errors" + "strconv" +) // -- int Value type intValue int @@ -12,8 +15,11 @@ func newIntValue(val int, p *int) *intValue { func (i *intValue) Set(s string) error { v, err := strconv.ParseInt(s, 0, 64) + if err != nil { + return errors.New("must be an integer") + } *i = intValue(v) - return err + return nil } func (i *intValue) Type() string { diff --git a/int16.go b/int16.go index f1a01d05..880bc2a7 100644 --- a/int16.go +++ b/int16.go @@ -1,6 +1,9 @@ package pflag -import "strconv" +import ( + "errors" + "strconv" +) // -- int16 Value type int16Value int16 @@ -12,8 +15,11 @@ func newInt16Value(val int16, p *int16) *int16Value { func (i *int16Value) Set(s string) error { v, err := strconv.ParseInt(s, 0, 16) + if err != nil { + return errors.New("must be an integer") + } *i = int16Value(v) - return err + return nil } func (i *int16Value) Type() string { diff --git a/int32.go b/int32.go index 9b95944f..2e0944d6 100644 --- a/int32.go +++ b/int32.go @@ -1,6 +1,9 @@ package pflag -import "strconv" +import ( + "errors" + "strconv" +) // -- int32 Value type int32Value int32 @@ -12,8 +15,11 @@ func newInt32Value(val int32, p *int32) *int32Value { func (i *int32Value) Set(s string) error { v, err := strconv.ParseInt(s, 0, 32) + if err != nil { + return errors.New("must be an integer") + } *i = int32Value(v) - return err + return nil } func (i *int32Value) Type() string { diff --git a/int64.go b/int64.go index 0026d781..2f7ae522 100644 --- a/int64.go +++ b/int64.go @@ -1,6 +1,9 @@ package pflag -import "strconv" +import ( + "errors" + "strconv" +) // -- int64 Value type int64Value int64 @@ -12,8 +15,11 @@ func newInt64Value(val int64, p *int64) *int64Value { func (i *int64Value) Set(s string) error { v, err := strconv.ParseInt(s, 0, 64) + if err != nil { + return errors.New("must be an integer") + } *i = int64Value(v) - return err + return nil } func (i *int64Value) Type() string { diff --git a/int8.go b/int8.go index 4da92228..91fea1a6 100644 --- a/int8.go +++ b/int8.go @@ -1,6 +1,9 @@ package pflag -import "strconv" +import ( + "errors" + "strconv" +) // -- int8 Value type int8Value int8 @@ -12,8 +15,11 @@ func newInt8Value(val int8, p *int8) *int8Value { func (i *int8Value) Set(s string) error { v, err := strconv.ParseInt(s, 0, 8) + if err != nil { + return errors.New("must be an integer") + } *i = int8Value(v) - return err + return nil } func (i *int8Value) Type() string { diff --git a/uint.go b/uint.go index dcbc2b75..64bd9665 100644 --- a/uint.go +++ b/uint.go @@ -1,6 +1,9 @@ package pflag -import "strconv" +import ( + "errors" + "strconv" +) // -- uint Value type uintValue uint @@ -12,8 +15,11 @@ func newUintValue(val uint, p *uint) *uintValue { func (i *uintValue) Set(s string) error { v, err := strconv.ParseUint(s, 0, 64) + if err != nil { + return errors.New("must be a non-negative integer") + } *i = uintValue(v) - return err + return nil } func (i *uintValue) Type() string { diff --git a/uint16.go b/uint16.go index 7e9914ed..b086a368 100644 --- a/uint16.go +++ b/uint16.go @@ -1,6 +1,9 @@ package pflag -import "strconv" +import ( + "errors" + "strconv" +) // -- uint16 value type uint16Value uint16 @@ -12,8 +15,11 @@ func newUint16Value(val uint16, p *uint16) *uint16Value { func (i *uint16Value) Set(s string) error { v, err := strconv.ParseUint(s, 0, 16) + if err != nil { + return errors.New("must be a non-negative integer") + } *i = uint16Value(v) - return err + return nil } func (i *uint16Value) Type() string { diff --git a/uint32.go b/uint32.go index d8024539..ae56168d 100644 --- a/uint32.go +++ b/uint32.go @@ -1,6 +1,9 @@ package pflag -import "strconv" +import ( + "errors" + "strconv" +) // -- uint32 value type uint32Value uint32 @@ -12,8 +15,11 @@ func newUint32Value(val uint32, p *uint32) *uint32Value { func (i *uint32Value) Set(s string) error { v, err := strconv.ParseUint(s, 0, 32) + if err != nil { + return errors.New("must be a non-negative integer") + } *i = uint32Value(v) - return err + return nil } func (i *uint32Value) Type() string { diff --git a/uint64.go b/uint64.go index 86d8c7e6..ffbdc072 100644 --- a/uint64.go +++ b/uint64.go @@ -1,6 +1,9 @@ package pflag -import "strconv" +import ( + "errors" + "strconv" +) // -- uint64 Value type uint64Value uint64 @@ -12,8 +15,11 @@ func newUint64Value(val uint64, p *uint64) *uint64Value { func (i *uint64Value) Set(s string) error { v, err := strconv.ParseUint(s, 0, 64) + if err != nil { + return errors.New("must be a non-negative integer") + } *i = uint64Value(v) - return err + return nil } func (i *uint64Value) Type() string { diff --git a/uint8.go b/uint8.go index bb0e83c1..bca63f67 100644 --- a/uint8.go +++ b/uint8.go @@ -1,6 +1,9 @@ package pflag -import "strconv" +import ( + "errors" + "strconv" +) // -- uint8 Value type uint8Value uint8 @@ -12,8 +15,11 @@ func newUint8Value(val uint8, p *uint8) *uint8Value { func (i *uint8Value) Set(s string) error { v, err := strconv.ParseUint(s, 0, 8) + if err != nil { + return errors.New("must be a non-negative integer") + } *i = uint8Value(v) - return err + return nil } func (i *uint8Value) Type() string {