Skip to content

Commit

Permalink
fix all any casting in unmarshal
Browse files Browse the repository at this point in the history
  • Loading branch information
NoBypass committed Oct 2, 2024
1 parent a325034 commit 424b8c8
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 39 deletions.
86 changes: 47 additions & 39 deletions marshal/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,20 @@ func (m *Marshaler) Unmarshal(src, dest any) (err error) {
}

func (m *Marshaler) unmarshal(src, dest reflect.Value) error {
switch src.Kind() {
if src.Kind() == reflect.Interface {
src = src.Elem()
}

switch dest.Kind() {
case reflect.Bool:
return m.simpleValueDecoder(src, dest)
case reflect.String:
if dest.Type() == reflect.TypeOf(time.Time{}) {
return m.timeDecoder(src, dest)
} else if dest.Type().String() == "time.Duration" {
return m.durationDecoder(src, dest)
}
return m.simpleValueDecoder(src, dest)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64:
if src.Type().AssignableTo(dest.Type()) {
return m.simpleValueDecoder(src, dest)
} else if dest.Type().String() == "time.Duration" {
return m.durationDecoder(src, dest)
}
return m.numberDecoder(src, dest)
case reflect.Slice:
Expand All @@ -45,9 +46,23 @@ func (m *Marshaler) unmarshal(src, dest reflect.Value) error {
return m.interfaceDecoder(src, dest)
case reflect.Map:
return m.mapDecoder(src, dest)
case reflect.Struct:
if dest.Type() == reflect.TypeOf(time.Time{}) {
return m.timeDecoder(src, dest)
}
return m.structDecoder(src, dest)
case reflect.Ptr:
return m.pointerDecoder(src, dest)
default:
return errs.ErrUnmarshal.Withf("cannot unmarshal %s", src.Type())
return errs.ErrUnmarshal.Withf("cannot unmarshal %s", dest.Type())
}
}

func (m *Marshaler) pointerDecoder(src, dest reflect.Value) error {
if dest.IsNil() {
dest.Set(reflect.New(dest.Type().Elem()))
}
return m.unmarshal(src, dest.Elem())
}

func (m *Marshaler) simpleValueDecoder(src, dest reflect.Value) error {
Expand Down Expand Up @@ -110,46 +125,39 @@ func (m *Marshaler) sliceDecoder(src, dest reflect.Value) error {
}

func (m *Marshaler) interfaceDecoder(src, dest reflect.Value) error {
if src.Elem().Type().AssignableTo(dest.Type()) {
dest.Set(src.Elem())
if src.Type().AssignableTo(dest.Type()) {
dest.Set(src)
return nil
} else {
return errs.ErrUnmarshal.Withf("cannot unmarshal interface")
}
return errs.ErrUnmarshal.Withf("cannot unmarshal %s to %s", src.Type(), dest.Type())
}

func (m *Marshaler) mapDecoder(src, dest reflect.Value) error {
if dest.Kind() == reflect.Ptr {
if dest.IsNil() {
dest.Set(reflect.New(dest.Type().Elem()))
}
dest = dest.Elem()
dest.Set(reflect.MakeMap(dest.Type()))
for _, key := range src.MapKeys() {
value := src.MapIndex(key)
dest.SetMapIndex(key, value)
}
return nil
}

func (m *Marshaler) structDecoder(src, dest reflect.Value) error {
for i := 0; i < dest.NumField(); i++ {
field := dest.Type().Field(i)
tag := m.tagOf(field)

if dest.Kind() == reflect.Map {
dest.Set(reflect.MakeMap(dest.Type()))
for _, key := range src.MapKeys() {
value := src.MapIndex(key)
dest.SetMapIndex(key, value)
mapVal := src.MapIndex(reflect.ValueOf(tag))
if !mapVal.IsValid() {
continue
}
} else if dest.Kind() == reflect.Struct {
for i := 0; i < dest.Type().NumField(); i++ {
field := dest.Type().Field(i)
tag := m.tagOf(field)

mapVal := src.MapIndex(reflect.ValueOf(tag))
if !mapVal.IsValid() {
continue
}

fieldVal := dest.Field(i)
if !fieldVal.CanSet() {
continue
}

if err := m.Unmarshal(mapVal.Interface(), fieldVal.Addr().Interface()); err != nil {
return err
}

fieldVal := dest.Field(i)
if !fieldVal.CanSet() {
continue
}

if err := m.unmarshal(mapVal, fieldVal); err != nil {
return err
}
}

Expand Down
11 changes: 11 additions & 0 deletions marshal/unmarshal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,4 +159,15 @@ func TestMarshaler_Unmarshal(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, []string{"a", "b"}, s)
})
t.Run("map as any to struct ", func(t *testing.T) {
type testStruct struct {
Test string
Num int
}
var s testStruct
var in any = map[string]any{"Test": "test", "Num": 42}
err := m.Unmarshal(in, &s)
assert.NoError(t, err)
assert.Equal(t, testStruct{"test", 42}, s)
})
}

0 comments on commit 424b8c8

Please # to comment.