Skip to content

Commit

Permalink
fix: address review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
somebadcode committed Feb 11, 2025
1 parent de9fde3 commit 0876e6b
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 33 deletions.
36 changes: 20 additions & 16 deletions flag_text.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,49 +5,53 @@ import (
"strings"
)

type TextMarshalUnMarshaller interface {
type TextMarshalUnmarshaller interface {
encoding.TextMarshaler
encoding.TextUnmarshaler
}

// TextFlag enables you to set types that satisfies [TextMarshalUnMarshaller] using flags such as log levels.
type TextFlag = FlagBase[TextMarshalUnMarshaller, StringConfig, TextValue]
// TextFlag enables you to set types that satisfies [TextMarshalUnmarshaller] using flags such as log levels.
type TextFlag = FlagBase[TextMarshalUnmarshaller, StringConfig, TextValue]

type TextValue struct {
Value TextMarshalUnMarshaller
Value *TextMarshalUnmarshaller
Config StringConfig
}

func (v TextValue) String() string {
text, err := v.Value.MarshalText()
func (f TextValue) String() string {
text, err := (*f.Value).MarshalText()
if err != nil {
return ""
}

return string(text)
}

func (v TextValue) Set(s string) error {
if v.Config.TrimSpace {
return v.Value.UnmarshalText([]byte(strings.TrimSpace(s)))
func (f TextValue) Set(s string) error {
if f.Config.TrimSpace {
s = strings.TrimSpace(s)
}

return v.Value.UnmarshalText([]byte(s))
return (*f.Value).UnmarshalText([]byte(s))
}

func (v TextValue) Get() any {
return v.Value
func (f TextValue) Get() any {
return *f.Value
}

func (v TextValue) Create(t TextMarshalUnMarshaller, _ *TextMarshalUnMarshaller, c StringConfig) Value {
func (f TextValue) Create(v TextMarshalUnmarshaller, p *TextMarshalUnmarshaller, c StringConfig) Value {
if v != nil {
*p = v
}

return &TextValue{
Value: t,
Value: p,
Config: c,
}
}

func (v TextValue) ToString(t TextMarshalUnMarshaller) string {
text, err := t.MarshalText()
func (f TextValue) ToString(v TextMarshalUnmarshaller) string {
text, err := v.MarshalText()
if err != nil {
return ""
}
Expand Down
76 changes: 59 additions & 17 deletions flag_text_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@ func (badMarshaller) MarshalText() ([]byte, error) {
return nil, errors.New("bad")
}

func ptr[T any](v T) *T {
return &v
}

func TestTextFlag(t *testing.T) {

tests := []struct {
name string
flag TextFlag
Expand All @@ -33,17 +38,17 @@ func TestTextFlag(t *testing.T) {
{
name: "empty",
flag: TextFlag{
Name: "log-level",
Value: &slog.LevelVar{},
Name: "log-level",
Destination: ptr[TextMarshalUnmarshaller](&slog.LevelVar{}),
},
want: "INFO",
},
{
name: "info",
flag: TextFlag{
Name: "log-level",
Value: &slog.LevelVar{},
Validator: func(v TextMarshalUnMarshaller) error {
Name: "log-level",
Destination: ptr[TextMarshalUnmarshaller](&slog.LevelVar{}),
Validator: func(v TextMarshalUnmarshaller) error {
text, err := v.MarshalText()
if err != nil {
return err
Expand All @@ -62,27 +67,27 @@ func TestTextFlag(t *testing.T) {
{
name: "debug",
flag: TextFlag{
Name: "log-level",
Value: &slog.LevelVar{},
Name: "log-level",
Destination: ptr[TextMarshalUnmarshaller](&slog.LevelVar{}),
},
args: []string{"--log-level", "debug"},
want: "DEBUG",
},
{
name: "debug_with_trim",
flag: TextFlag{
Name: "log-level",
Value: &slog.LevelVar{},
Config: StringConfig{TrimSpace: true},
Name: "log-level",
Destination: ptr[TextMarshalUnmarshaller](&slog.LevelVar{}),
Config: StringConfig{TrimSpace: true},
},
args: []string{"--log-level", " debug "},
want: "DEBUG",
},
{
name: "invalid",
flag: TextFlag{
Name: "log-level",
Value: &slog.LevelVar{},
Name: "log-level",
Destination: ptr[TextMarshalUnmarshaller](&slog.LevelVar{}),
},
args: []string{"--log-level", "invalid"},
want: "INFO",
Expand All @@ -91,12 +96,45 @@ func TestTextFlag(t *testing.T) {
{
name: "bad_marshaller",
flag: TextFlag{
Name: "text",
Value: &badMarshaller{},
Name: "text",
Value: &badMarshaller{},
Destination: ptr[TextMarshalUnmarshaller](&badMarshaller{}),
},
args: []string{"--text", "foo"},
wantErr: true,
},
{
name: "default",
flag: TextFlag{
Name: "log-level",
Value: func() *slog.LevelVar {
var l slog.LevelVar

l.Set(slog.LevelWarn)

return &l
}(),
Destination: ptr[TextMarshalUnmarshaller](&slog.LevelVar{}),
},
args: []string{},
want: "WARN",
},
{
name: "override_default",
flag: TextFlag{
Name: "log-level",
Value: func() *slog.LevelVar {
var l slog.LevelVar

l.Set(slog.LevelWarn)

return &l
}(),
Destination: ptr[TextMarshalUnmarshaller](&slog.LevelVar{}),
},
args: []string{"--log-level", "error"},
want: "ERROR",
},
}

t.Parallel()
Expand All @@ -114,12 +152,16 @@ func TestTextFlag(t *testing.T) {

require.False(t, (err != nil) && !tt.wantErr, tt.name)

if tt.flag.Value != nil {
assert.Equal(t, tt.want, tt.flag.GetDefaultText())
}

if tt.wantErr {
require.Equal(t, tt.flag.GetDefaultText(), tt.want)
return
}

assert.Equal(t, set.Lookup(tt.flag.Name).Value.String(), tt.want)
assert.Equal(t, tt.flag.GetDefaultText(), tt.want)
assert.Equal(t, tt.want, set.Lookup(tt.flag.Name).Value.String())

})
}
}

0 comments on commit 0876e6b

Please # to comment.