diff --git a/assert/assertions.go b/assert/assertions.go index b362b4a29..6b70d3612 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -102,11 +102,20 @@ func ObjectsExportedFieldsAreEqual(expected, actual interface{}) bool { field := expectedType.Field(i) isExported := field.PkgPath == "" // should use field.IsExported() but it's not available in Go 1.16.5 if isExported { + expectedField := expectedValue.Field(i) + actualField := actualValue.Field(i) + var equal bool - if field.Type.Kind() == reflect.Struct { - equal = ObjectsExportedFieldsAreEqual(expectedValue.Field(i).Interface(), actualValue.Field(i).Interface()) - } else { - equal = ObjectsAreEqualValues(expectedValue.Field(i).Interface(), actualValue.Field(i).Interface()) + switch field.Type.Kind() { + case reflect.Struct: + equal = ObjectsExportedFieldsAreEqual(expectedField.Interface(), actualField.Interface()) + case reflect.Pointer: + if expectedField.IsNil() || actualField.IsNil() { + return expectedField == actualField + } + equal = ObjectsExportedFieldsAreEqual(expectedField.Elem().Interface(), actualField.Elem().Interface()) + default: + equal = ObjectsAreEqualValues(expectedField.Interface(), actualField.Interface()) } if !equal { diff --git a/assert/assertions_test.go b/assert/assertions_test.go index 5eaf2af8d..3dd4ed0d2 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -166,6 +166,10 @@ func TestObjectsExportedFieldsAreEqual(t *testing.T) { foo interface{} } + type S3 struct { + ExportedPointer *Nested + } + cases := []struct { expected interface{} actual interface{} @@ -180,6 +184,11 @@ func TestObjectsExportedFieldsAreEqual(t *testing.T) { {S{1, Nested{2, 3}, 4, Nested{5, 6}}, S{1, Nested{"a", 3}, 4, Nested{5, 6}}, false}, {S{1, Nested{2, 3}, 4, Nested{5, 6}}, S2{1}, false}, {1, S{1, Nested{2, 3}, 4, Nested{5, 6}}, false}, + {S3{&Nested{2, 3}}, S3{&Nested{2, 3}}, true}, + {S3{&Nested{2, 3}}, S3{&Nested{2, 4}}, true}, + {S3{&Nested{2, 3}}, S3{&Nested{"a", 3}}, false}, + {S3{&Nested{2, 3}}, S3{}, false}, + {S3{}, S3{}, true}, } for _, c := range cases {