From 7fd817b182f2683be5972481eb580f4a65e01b25 Mon Sep 17 00:00:00 2001 From: Anand Gaitonde Date: Tue, 16 Oct 2018 22:07:49 -0700 Subject: [PATCH] Allow custom flags to read '-' if they are a value validator [https://www.pivotaltracker.com/story/show/127131179] [https://www.pivotaltracker.com/story/show/160742110] --- convert.go | 9 +++++++++ option.go | 34 ++++++++++++++++++++++++++++++++ parser.go | 4 ++-- parser_test.go | 53 ++++++++++++++++++++++++++++++++++++++++---------- 4 files changed, 88 insertions(+), 12 deletions(-) diff --git a/convert.go b/convert.go index 984aac8..cda29b2 100644 --- a/convert.go +++ b/convert.go @@ -28,6 +28,15 @@ type Unmarshaler interface { UnmarshalFlag(value string) error } +// ValueValidator is the interface implemented by types that can validate a +// flag argument themselves. The provided value is directly passed from the +// command line. +type ValueValidator interface { + // IsValidValue returns an error if the provided string value is valid for + // the flag. + IsValidValue(value string) error +} + func getBase(options multiTag, base int) (int, error) { sbase := options.Get("base") diff --git a/option.go b/option.go index 8e306d9..5cebb54 100644 --- a/option.go +++ b/option.go @@ -389,6 +389,30 @@ func (option *Option) isUnmarshaler() Unmarshaler { return nil } +func (option *Option) isValueValidator() ValueValidator { + v := option.value + + for { + if !v.CanInterface() { + break + } + + i := v.Interface() + + if u, ok := i.(ValueValidator); ok { + return u + } + + if !v.CanAddr() { + break + } + + v = v.Addr() + } + + return nil +} + func (option *Option) isBool() bool { tp := option.value.Type() @@ -507,3 +531,13 @@ func (option *Option) shortAndLongName() string { return ret.String() } + +func (option *Option) isValidValue(arg string) error { + if validator := option.isValueValidator(); validator != nil { + return validator.IsValidValue(arg) + } + if argumentIsOption(arg) && !(option.isSignedNumber() && len(arg) > 1 && arg[0] == '-' && arg[1] >= '0' && arg[1] <= '9') { + return fmt.Errorf("expected argument for flag `%s', but got option `%s'", option, arg) + } + return nil +} diff --git a/parser.go b/parser.go index 54816a6..a5347b0 100644 --- a/parser.go +++ b/parser.go @@ -532,8 +532,8 @@ func (p *Parser) parseOption(s *parseState, name string, option *Option, canarg } else { arg = s.pop() - if argumentIsOption(arg) && !(option.isSignedNumber() && len(arg) > 1 && arg[0] == '-' && arg[1] >= '0' && arg[1] <= '9') { - return newErrorf(ErrExpectedArgument, "expected argument for flag `%s', but got option `%s'", option, arg) + if validationErr := option.isValidValue(arg); validationErr != nil { + return newErrorf(ErrExpectedArgument, validationErr.Error()) } else if p.Options&PassDoubleDash != 0 && arg == "--" { return newErrorf(ErrExpectedArgument, "expected argument for flag `%s', but got double dash `--'", option) } diff --git a/parser_test.go b/parser_test.go index f0c768d..bd2f464 100644 --- a/parser_test.go +++ b/parser_test.go @@ -1,6 +1,7 @@ package flags import ( + "errors" "fmt" "os" "reflect" @@ -363,6 +364,22 @@ func TestEnvDefaults(t *testing.T) { } } +type CustomFlag struct { + Value string +} + +func (c *CustomFlag) UnmarshalFlag(s string) error { + c.Value = s + return nil +} + +func (c *CustomFlag) IsValidValue(s string) error { + if !(s == "-1" || s == "-foo") { + return errors.New("invalid flag value") + } + return nil +} + func TestOptionAsArgument(t *testing.T) { var tests = []struct { args []string @@ -419,30 +436,46 @@ func TestOptionAsArgument(t *testing.T) { rest: []string{"-", "-"}, }, { - // Accept arguments which start with '-' if the next character is a digit, for number options only + // Accept arguments which start with '-' if the next character is a digit args: []string{"--int-slice", "-3"}, }, { - // Accept arguments which start with '-' if the next character is a digit, for number options only + // Accept arguments which start with '-' if the next character is a digit args: []string{"--int16", "-3"}, }, { - // Accept arguments which start with '-' if the next character is a digit, for number options only + // Accept arguments which start with '-' if the next character is a digit args: []string{"--float32", "-3.2"}, }, { - // Accept arguments which start with '-' if the next character is a digit, for number options only + // Accept arguments which start with '-' if the next character is a digit args: []string{"--float32ptr", "-3.2"}, }, + { + // Accept arguments for values that pass the IsValidValue fuction for value validators + args: []string{"--custom-flag", "-foo"}, + }, + { + // Accept arguments for values that pass the IsValidValue fuction for value validators + args: []string{"--custom-flag", "-1"}, + }, + { + // Rejects arguments for values that fail the IsValidValue fuction for value validators + args: []string{"--custom-flag", "-2"}, + expectError: true, + errType: ErrExpectedArgument, + errMsg: "invalid flag value", + }, } var opts struct { - StringSlice []string `long:"string-slice"` - IntSlice []int `long:"int-slice"` - Int16 int16 `long:"int16"` - Float32 float32 `long:"float32"` - Float32Ptr *float32 `long:"float32ptr"` - OtherOption bool `long:"other-option" short:"o"` + StringSlice []string `long:"string-slice"` + IntSlice []int `long:"int-slice"` + Int16 int16 `long:"int16"` + Float32 float32 `long:"float32"` + Float32Ptr *float32 `long:"float32ptr"` + OtherOption bool `long:"other-option" short:"o"` + Custom CustomFlag `long:"custom-flag" short:"c"` } for _, test := range tests {