diff --git a/flag_text.go b/flag_text.go index a9f7e1ba01..2a484f374a 100644 --- a/flag_text.go +++ b/flag_text.go @@ -5,21 +5,21 @@ import ( "strings" ) -type TextMarshalUnMarshaller interface { +type TextMarshalUnmarshaller interface { encoding.TextMarshaler encoding.TextUnmarshaler } -// TextFlag enables you to set types that satisfies [TextMarshalUnMarshaller] using flags such as log levels. -type TextFlag = FlagBase[TextMarshalUnMarshaller, StringConfig, TextValue] +// TextFlag enables you to set types that satisfies [TextMarshalUnmarshaller] using flags such as log levels. +type TextFlag = FlagBase[TextMarshalUnmarshaller, StringConfig, TextValue] type TextValue struct { - Value TextMarshalUnMarshaller + Value *TextMarshalUnmarshaller Config StringConfig } -func (v TextValue) String() string { - text, err := v.Value.MarshalText() +func (f TextValue) String() string { + text, err := (*f.Value).MarshalText() if err != nil { return "" } @@ -27,27 +27,31 @@ func (v TextValue) String() string { return string(text) } -func (v TextValue) Set(s string) error { - if v.Config.TrimSpace { - return v.Value.UnmarshalText([]byte(strings.TrimSpace(s))) +func (f TextValue) Set(s string) error { + if f.Config.TrimSpace { + s = strings.TrimSpace(s) } - return v.Value.UnmarshalText([]byte(s)) + return (*f.Value).UnmarshalText([]byte(s)) } -func (v TextValue) Get() any { - return v.Value +func (f TextValue) Get() any { + return *f.Value } -func (v TextValue) Create(t TextMarshalUnMarshaller, _ *TextMarshalUnMarshaller, c StringConfig) Value { +func (f TextValue) Create(v TextMarshalUnmarshaller, p *TextMarshalUnmarshaller, c StringConfig) Value { + if v != nil { + *p = v + } + return &TextValue{ - Value: t, + Value: p, Config: c, } } -func (v TextValue) ToString(t TextMarshalUnMarshaller) string { - text, err := t.MarshalText() +func (f TextValue) ToString(v TextMarshalUnmarshaller) string { + text, err := v.MarshalText() if err != nil { return "" } diff --git a/flag_text_test.go b/flag_text_test.go index 468299669f..a232c4458a 100644 --- a/flag_text_test.go +++ b/flag_text_test.go @@ -22,7 +22,12 @@ func (badMarshaller) MarshalText() ([]byte, error) { return nil, errors.New("bad") } +func ptr[T any](v T) *T { + return &v +} + func TestTextFlag(t *testing.T) { + tests := []struct { name string flag TextFlag @@ -33,17 +38,17 @@ func TestTextFlag(t *testing.T) { { name: "empty", flag: TextFlag{ - Name: "log-level", - Value: &slog.LevelVar{}, + Name: "log-level", + Destination: ptr[TextMarshalUnmarshaller](&slog.LevelVar{}), }, want: "INFO", }, { name: "info", flag: TextFlag{ - Name: "log-level", - Value: &slog.LevelVar{}, - Validator: func(v TextMarshalUnMarshaller) error { + Name: "log-level", + Destination: ptr[TextMarshalUnmarshaller](&slog.LevelVar{}), + Validator: func(v TextMarshalUnmarshaller) error { text, err := v.MarshalText() if err != nil { return err @@ -62,8 +67,8 @@ func TestTextFlag(t *testing.T) { { name: "debug", flag: TextFlag{ - Name: "log-level", - Value: &slog.LevelVar{}, + Name: "log-level", + Destination: ptr[TextMarshalUnmarshaller](&slog.LevelVar{}), }, args: []string{"--log-level", "debug"}, want: "DEBUG", @@ -71,9 +76,9 @@ func TestTextFlag(t *testing.T) { { name: "debug_with_trim", flag: TextFlag{ - Name: "log-level", - Value: &slog.LevelVar{}, - Config: StringConfig{TrimSpace: true}, + Name: "log-level", + Destination: ptr[TextMarshalUnmarshaller](&slog.LevelVar{}), + Config: StringConfig{TrimSpace: true}, }, args: []string{"--log-level", " debug "}, want: "DEBUG", @@ -81,8 +86,8 @@ func TestTextFlag(t *testing.T) { { name: "invalid", flag: TextFlag{ - Name: "log-level", - Value: &slog.LevelVar{}, + Name: "log-level", + Destination: ptr[TextMarshalUnmarshaller](&slog.LevelVar{}), }, args: []string{"--log-level", "invalid"}, want: "INFO", @@ -91,12 +96,45 @@ func TestTextFlag(t *testing.T) { { name: "bad_marshaller", flag: TextFlag{ - Name: "text", - Value: &badMarshaller{}, + Name: "text", + Value: &badMarshaller{}, + Destination: ptr[TextMarshalUnmarshaller](&badMarshaller{}), }, args: []string{"--text", "foo"}, wantErr: true, }, + { + name: "default", + flag: TextFlag{ + Name: "log-level", + Value: func() *slog.LevelVar { + var l slog.LevelVar + + l.Set(slog.LevelWarn) + + return &l + }(), + Destination: ptr[TextMarshalUnmarshaller](&slog.LevelVar{}), + }, + args: []string{}, + want: "WARN", + }, + { + name: "override_default", + flag: TextFlag{ + Name: "log-level", + Value: func() *slog.LevelVar { + var l slog.LevelVar + + l.Set(slog.LevelWarn) + + return &l + }(), + Destination: ptr[TextMarshalUnmarshaller](&slog.LevelVar{}), + }, + args: []string{"--log-level", "error"}, + want: "ERROR", + }, } t.Parallel() @@ -114,12 +152,16 @@ func TestTextFlag(t *testing.T) { require.False(t, (err != nil) && !tt.wantErr, tt.name) + if tt.flag.Value != nil { + assert.Equal(t, tt.want, tt.flag.GetDefaultText()) + } + if tt.wantErr { - require.Equal(t, tt.flag.GetDefaultText(), tt.want) + return } - assert.Equal(t, set.Lookup(tt.flag.Name).Value.String(), tt.want) - assert.Equal(t, tt.flag.GetDefaultText(), tt.want) + assert.Equal(t, tt.want, set.Lookup(tt.flag.Name).Value.String()) + }) } }