diff --git a/diff/diff.go b/diff/diff.go index 183eb1f..4bf74b2 100644 --- a/diff/diff.go +++ b/diff/diff.go @@ -43,8 +43,8 @@ func Diff(lhs, rhs interface{}, opts ...ConfigOpt) (Differ, error) { } func diff(c config, lhs, rhs interface{}, visited *visited) (Differ, error) { - lhsVal := reflect.ValueOf(lhs) - rhsVal := reflect.ValueOf(rhs) + lhsVal, lhs := indirectValueOf(lhs) + rhsVal, rhs := indirectValueOf(rhs) if d, ok := nilCheck(lhs, rhs); ok { return d, nil @@ -56,7 +56,7 @@ func diff(c config, lhs, rhs interface{}, visited *visited) (Differ, error) { return types{lhs, rhs}, ErrCyclic } - if valuesAreScalar(lhsVal, rhsVal) { + if valueIsScalar(lhsVal) && valueIsScalar(rhsVal) { return scalar{lhs, rhs}, nil } if lhsVal.Kind() != rhsVal.Kind() { @@ -75,15 +75,22 @@ func diff(c config, lhs, rhs interface{}, visited *visited) (Differ, error) { return types{lhs, rhs}, &ErrUnsupported{lhsVal.Type(), rhsVal.Type()} } -func valuesAreScalar(lhs, rhs reflect.Value) bool { - if lhs.Kind() == reflect.Struct || rhs.Kind() == reflect.Struct { - return false +func indirectValueOf(i interface{}) (reflect.Value, interface{}) { + v := reflect.Indirect(reflect.ValueOf(i)) + if !v.IsValid() || !v.CanInterface() { + return reflect.ValueOf(i), i } - if lhs.Kind() == reflect.Array || rhs.Kind() == reflect.Array { + + return v, v.Interface() +} + +func valueIsScalar(v reflect.Value) bool { + switch v.Kind() { + default: + return v.Type().Comparable() + case reflect.Struct, reflect.Array, reflect.Ptr, reflect.Chan: return false } - - return lhs.Type().Comparable() && rhs.Type().Comparable() } func nilCheck(lhs, rhs interface{}) (Differ, bool) { diff --git a/diff/diff_test.go b/diff/diff_test.go index a730ce2..d1aa575 100644 --- a/diff/diff_test.go +++ b/diff/diff_test.go @@ -2,6 +2,7 @@ package diff import ( "fmt" + "reflect" "strings" "testing" ) @@ -1128,3 +1129,32 @@ func TestIsSlice(t *testing.T) { t.Error("IsSlice(Diff(map{...}, map{...})) = false, expected true") } } + +func TestValueIsScalar(t *testing.T) { + for _, test := range []struct { + In interface{} + Expected bool + }{ + {int(42), true}, + {int8(23), true}, + {"foo", true}, + {true, true}, + {float32(1.2), true}, + {complex(5, -3), true}, + + {[]byte("foo"), false}, + {struct{}{}, false}, + {&struct{}{}, false}, + {[]int{1, 2, 3}, false}, + {[3]int{1, 2, 3}, false}, + {map[string]int{"foo": 22}, false}, + {func() {}, false}, + {make(chan struct{}), false}, + } { + v := reflect.ValueOf(test.In) + got := valueIsScalar(v) + if got != test.Expected { + t.Errorf("valueIsScalar(%T) = %v, expected %v", test.In, got, test.Expected) + } + } +}