diff --git a/flag_set.go b/flag_set.go index 4107631..b94ad87 100644 --- a/flag_set.go +++ b/flag_set.go @@ -241,16 +241,8 @@ func (fs *FlagSet) parseShortFlag(arg string, args []string) ([]string, error) { } func (fs *FlagSet) parseLongFlag(arg string, args []string) ([]string, error) { - var ( - name string - value string - ) - - if equals := strings.IndexRune(arg, '='); equals > 0 { - arg, value = arg[:equals], arg[equals+1:] - } - - name = strings.TrimPrefix(arg, "--") + name, value, eqFound := strings.Cut(arg, "=") + name = strings.TrimPrefix(name, "--") f := fs.findLongFlag(name) if f == nil { @@ -264,22 +256,24 @@ func (fs *FlagSet) parseLongFlag(arg string, args []string) ([]string, error) { } } - if value == "" { + if eqFound && f.isBoolFlag && value == "" { + value = "true" // `--debug=` amounts to `--debug=true` + } + + if value == "" && !eqFound { switch { case f.isBoolFlag: - value = "true" // `-b` or `--foo` default to true + value = "true" // `--foo` defaults to true if len(args) > 0 { if _, err := strconv.ParseBool(args[0]); err == nil { - value = args[0] // `-b true` or `--foo false` should also work + value = args[0] // `--foo false` should also work args = args[1:] } } - case !f.isBoolFlag && len(args) > 0: + case len(args) > 0: value, args = args[0], args[1:] - case !f.isBoolFlag && len(args) <= 0: - return nil, fmt.Errorf("missing value") default: - panic("unreachable") + return nil, fmt.Errorf("missing value") } } diff --git a/flag_set_test.go b/flag_set_test.go index 928c6d7..638022e 100644 --- a/flag_set_test.go +++ b/flag_set_test.go @@ -83,6 +83,7 @@ func TestFlagSet_Bool(t *testing.T) { {args: []string{"--help"}, wantX: false, wantY: true, wantErr: ff.ErrHelp}, {args: []string{"--xflag", "-h"}, wantX: true, wantY: true, wantErr: ff.ErrHelp}, {args: []string{"-y", "--help"}, wantX: false, wantY: false, wantErr: ff.ErrHelp}, + {args: []string{"--xflag=", "--help"}, wantX: true, wantY: true, wantErr: ff.ErrHelp}, } { t.Run(strings.Join(test.args, " "), func(t *testing.T) { fs := ff.NewFlagSet(t.Name()) diff --git a/parse_test.go b/parse_test.go index 62f9b3e..6eb79ea 100644 --- a/parse_test.go +++ b/parse_test.go @@ -233,6 +233,12 @@ func TestParse_FlagSet(t *testing.T) { Args: []string{`--str`, `foo`, `--help`, `-b`}, Want: fftest.Vars{S: "foo", B: false, WantParseErrorIs: ff.ErrHelp}, }, + { + Name: "--str= -a", + Constructors: []fftest.Constructor{fftest.CoreConstructor}, + Args: []string{`--str=`, `-a`}, + Want: fftest.Vars{S: "", A: true}, + }, { Name: "-s foo -f 1.23", Constructors: []fftest.Constructor{fftest.CoreConstructor},