diff --git a/assert/assertions.go b/assert/assertions.go index 66df7ab6d..a23c9b2bd 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -15,9 +15,6 @@ import ( "time" "unicode" "unicode/utf8" - - "github.com/davecgh/go-spew/spew" - "github.com/pmezard/go-difflib/difflib" ) //go:generate go run ../_codegen/main.go -output-package=assert -template=assertion_format.go.tmpl @@ -51,7 +48,7 @@ func ObjectsAreEqual(expected, actual interface{}) bool { } return bytes.Equal(exp, act) } - return reflect.DeepEqual(expected, actual) + return compare(expected, actual) } @@ -69,7 +66,7 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool { expectedValue := reflect.ValueOf(expected) if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) { // Attempt comparison after type conversion - return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual) + return compare(expectedValue.Convert(actualType).Interface(), actual) } return false @@ -90,7 +87,7 @@ func CallerInfo() []string { ok := false name := "" - callers := []string{} + var callers []string for i := 0; ; i++ { pc, file, line, ok = runtime.Caller(i) if !ok { @@ -449,7 +446,7 @@ func isEmpty(object interface{}) bool { // for all other types, compare against the zero value default: zero := reflect.Zero(objValue.Type()) - return reflect.DeepEqual(object, zero.Interface()) + return compare(object, zero.Interface()) } } @@ -1179,7 +1176,7 @@ func matchRegexp(rx interface{}, str interface{}) bool { r = regexp.MustCompile(fmt.Sprint(rx)) } - return (r.FindStringIndex(fmt.Sprint(str)) != nil) + return r.FindStringIndex(fmt.Sprint(str)) != nil } @@ -1224,7 +1221,7 @@ func Zero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() } - if i != nil && !reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) { + if i != nil && !compare(i, reflect.Zero(reflect.TypeOf(i)).Interface()) { return Fail(t, fmt.Sprintf("Should be zero, but was %v", i), msgAndArgs...) } return true @@ -1235,7 +1232,7 @@ func NotZero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool { if h, ok := t.(tHelper); ok { h.Helper() } - if i == nil || reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) { + if i == nil || compare(i, reflect.Zero(reflect.TypeOf(i)).Interface()) { return Fail(t, fmt.Sprintf("Should not be zero, but was %v", i), msgAndArgs...) } return true @@ -1297,51 +1294,6 @@ func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{ return Equal(t, expectedJSONAsInterface, actualJSONAsInterface, msgAndArgs...) } -func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) { - t := reflect.TypeOf(v) - k := t.Kind() - - if k == reflect.Ptr { - t = t.Elem() - k = t.Kind() - } - return t, k -} - -// diff returns a diff of both values as long as both are of the same type and -// are a struct, map, slice or array. Otherwise it returns an empty string. -func diff(expected interface{}, actual interface{}) string { - if expected == nil || actual == nil { - return "" - } - - et, ek := typeAndKind(expected) - at, _ := typeAndKind(actual) - - if et != at { - return "" - } - - if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array { - return "" - } - - e := spewConfig.Sdump(expected) - a := spewConfig.Sdump(actual) - - diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{ - A: difflib.SplitLines(e), - B: difflib.SplitLines(a), - FromFile: "Expected", - FromDate: "", - ToFile: "Actual", - ToDate: "", - Context: 1, - }) - - return "\n\nDiff:\n" + diff -} - // validateEqualArgs checks whether provided arguments can be safely used in the // Equal/NotEqual functions. func validateEqualArgs(expected, actual interface{}) error { @@ -1358,13 +1310,6 @@ func isFunction(arg interface{}) bool { return reflect.TypeOf(arg).Kind() == reflect.Func } -var spewConfig = spew.ConfigState{ - Indent: " ", - DisablePointerAddresses: true, - DisableCapacities: true, - SortKeys: true, -} - type tHelper interface { Helper() } diff --git a/assert/assertions_test.go b/assert/assertions_test.go index d8e0d7435..9be097fd4 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -258,7 +258,7 @@ func TestEqualFormatting(t *testing.T) { want string }{ {equalWant: "want", equalGot: "got", want: "\tassertions.go:[0-9]+: \n\t\t\tError Trace:\t\n\t\t\tError: \tNot equal: \n\t\t\t \texpected: \"want\"\n\t\t\t \tactual : \"got\"\n"}, - {equalWant: "want", equalGot: "got", msgAndArgs: []interface{}{"hello, %v!", "world"}, want: "\tassertions.go:[0-9]+: \n\t\t\tError Trace:\t\n\t\t\tError: \tNot equal: \n\t\t\t \texpected: \"want\"\n\t\t\t \tactual : \"got\"\n\t\t\tMessages: \thello, world!\n"}, + {equalWant: "want", equalGot: "got", msgAndArgs: []interface{}{"hello, %v!", "world"}, want: `\tassertions.go:\d+: \n\t\t\tError Trace:\t\n\t\t\tError: \tNot equal: \n\t\t\t \texpected: "want"\n\t\t\t \tactual : "got"\n\t\t\tMessages: \thello, world!\n`}, } { mockT := &bufferT{} Equal(mockT, currCase.equalWant, currCase.equalGot, currCase.msgAndArgs...) @@ -1400,11 +1400,9 @@ func TestDiff(t *testing.T) { Diff: --- Expected +++ Actual -@@ -1,3 +1,3 @@ - (struct { foo string }) { -- foo: (string) (len=5) "hello" -+ foo: (string) (len=3) "bar" - } +root.foo: + -: "hello" + +: "bar" ` actual := diff( struct{ foo string }{"hello"}, @@ -1417,14 +1415,9 @@ Diff: Diff: --- Expected +++ Actual -@@ -2,5 +2,5 @@ - (int) 1, -- (int) 2, - (int) 3, -- (int) 4 -+ (int) 5, -+ (int) 7 - } +{[]int}: + -: []int{1, 2, 3, 4} + +: []int{1, 3, 5, 7} ` actual = diff( []int{1, 2, 3, 4}, @@ -1437,13 +1430,9 @@ Diff: Diff: --- Expected +++ Actual -@@ -2,4 +2,4 @@ - (int) 1, -- (int) 2, -- (int) 3 -+ (int) 3, -+ (int) 5 - } +{[]int}: + -: []int{1, 2, 3} + +: []int{1, 3, 5} ` actual = diff( []int{1, 2, 3, 4}[0:3], @@ -1456,16 +1445,18 @@ Diff: Diff: --- Expected +++ Actual -@@ -1,6 +1,6 @@ - (map[string]int) (len=4) { -- (string) (len=4) "four": (int) 4, -+ (string) (len=4) "five": (int) 5, - (string) (len=3) "one": (int) 1, -- (string) (len=5) "three": (int) 3, -- (string) (len=3) "two": (int) 2 -+ (string) (len=5) "seven": (int) 7, -+ (string) (len=5) "three": (int) 3 - } +{map[string]int}["five"]: + -: + +: 5 +{map[string]int}["four"]: + -: 4 + +: +{map[string]int}["seven"]: + -: + +: 7 +{map[string]int}["two"]: + -: 2 + +: ` actual = diff( @@ -1555,7 +1546,7 @@ func TestBytesEqual(t *testing.T) { {nil, make([]byte, 0)}, } for i, c := range cases { - Equal(t, reflect.DeepEqual(c.a, c.b), ObjectsAreEqual(c.a, c.b), "case %d failed", i+1) + Equal(t, compare(c.a, c.b), ObjectsAreEqual(c.a, c.b), "case %d failed", i+1) } } diff --git a/assert/compare.go b/assert/compare.go new file mode 100644 index 000000000..b1c2624d9 --- /dev/null +++ b/assert/compare.go @@ -0,0 +1,151 @@ +package assert + +import ( + "reflect" + + "github.com/google/go-cmp/cmp" +) + +// compare compares two objects +func compare(expected, actual interface{}) bool { + return cmp.Equal(expected, actual, compareOptions(expected, actual)...) +} + +// diff returns a diff of both values as long as both are of the same type and +// are a struct, map, slice or array. Otherwise it returns an empty string. +func diff(expected, actual interface{}) string { + if expected == nil || actual == nil { + return "" + } + + et, ek := typeAndKind(expected) + at, _ := typeAndKind(actual) + + if et != at { + return "" + } + + if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array { + return "" + } + + diff := cmp.Diff(expected, actual, compareOptions(expected, actual)...) + if diff != "" { + diff = "\n\nDiff:\n--- Expected\n+++ Actual\n" + diff + } + return diff +} + +// compareOptions are cmp.Options used for cmp.Equal and cmp.Diff to compare +// two general objects for testing purposes +func compareOptions(expected, actual interface{}) cmp.Options { + return cmp.Options{ + deepAllowUnexported(expected, actual), + compareIdenticalPointers, + } +} + +// deepAllowUnexported returns option for cmp.Equal or cmp.Diff in which +// all unexported fields in the two compared types (recursively) are +// allowed. +// Code from https://github.com/google/go-cmp/issues/40 with modification +// to work with cyclic struct +func deepAllowUnexported(vs ...interface{}) cmp.Option { + var ( + // allUnexported is a set of types to be added to the unexported list + allUnexported = make(map[reflect.Type]bool) + // visited are list of pointer which are visited during the recursive collection + // of the referenced types. + // It is used to detect cycles and prevent infinite recursion. + visited = make(map[uintptr]bool) + ) + + // Collect all types from all given objects + for _, v := range vs { + structTypes(reflect.ValueOf(v), allUnexported, visited) + } + + // Collect the referenced types + var types []interface{} + for t := range allUnexported { + types = append(types, reflect.New(t).Elem().Interface()) + } + + // Return cmp option which allows all unexported fields in all the collected types + return cmp.AllowUnexported(types...) +} + +// structTypes is a recursive search for all referenced types from a given object. +// It searches recursively in all the given object fields and references, and put the +// collected type in the `m` set. +// It uses the `visited` set to detect cycles and prevent infinite recursion +func structTypes(v reflect.Value, m map[reflect.Type]bool, visited map[uintptr]bool) { + if !v.IsValid() { + return + } + + // dive in according to the kind of the given object + switch v.Kind() { + case reflect.Ptr: + if v.IsNil() { + return + } + // prevent infinite recursion + if visited[v.Elem().UnsafeAddr()] { + return + } + // remember jumping to a pointed address + visited[v.Elem().UnsafeAddr()] = true + structTypes(v.Elem(), m, visited) + case reflect.Interface: + if v.IsNil() { + return + } + // search into the object that implement the interface + structTypes(v.Elem(), m, visited) + case reflect.Slice, reflect.Array: + // recursively search in all the slice/array objects + for i := 0; i < v.Len(); i++ { + structTypes(v.Index(i), m, visited) + } + case reflect.Map: + // recursively search in all the map values + for _, k := range v.MapKeys() { + structTypes(v.MapIndex(k), m, visited) + } + case reflect.Struct: + // add the type to the collected types. + m[v.Type()] = true + // recursively search in all the struct fields + for i := 0; i < v.NumField(); i++ { + structTypes(v.Field(i), m, visited) + } + } +} + +// compareIdenticalPointers is a cmp option that returns true if the two compared +// objects are pointers and are pointing on the same thing. +var compareIdenticalPointers = cmp.FilterPath(func(p cmp.Path) bool { + // Filter for pointer kinds only. + t := p.Last().Type() + return t != nil && t.Kind() == reflect.Ptr +}, cmp.FilterValues(func(x, y interface{}) bool { + // Filter for pointer values that are identical. + vx := reflect.ValueOf(x) + vy := reflect.ValueOf(y) + return vx.IsValid() && vy.IsValid() && vx.Pointer() == vy.Pointer() +}, cmp.Comparer(func(_, _ interface{}) bool { + // Consider them equal no matter what. + return true +}))) + +func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) { + t := reflect.TypeOf(v) + k := t.Kind() + + if k == reflect.Ptr { + t = t.Elem() + k = t.Kind() + } + return t, k +}