diff --git a/assert/assertions.go b/assert/assertions.go index bfb640abc..4e91332bb 100644 --- a/assert/assertions.go +++ b/assert/assertions.go @@ -502,7 +502,13 @@ func Same(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) b h.Helper() } - if !samePointers(expected, actual) { + same, ok := samePointers(expected, actual) + if !ok { + return Fail(t, "Both arguments must be pointers", msgAndArgs...) + } + + if !same { + // both are pointers but not the same type & pointing to the same address return Fail(t, fmt.Sprintf("Not same: \n"+ "expected: %p %#v\n"+ "actual : %p %#v", expected, expected, actual, actual), msgAndArgs...) @@ -522,7 +528,13 @@ func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{} h.Helper() } - if samePointers(expected, actual) { + same, ok := samePointers(expected, actual) + if !ok { + //fails when the arguments are not pointers + return !(Fail(t, "Both arguments must be pointers", msgAndArgs...)) + } + + if same { return Fail(t, fmt.Sprintf( "Expected and actual point to the same object: %p %#v", expected, expected), msgAndArgs...) @@ -530,21 +542,23 @@ func NotSame(t TestingT, expected, actual interface{}, msgAndArgs ...interface{} return true } -// samePointers compares two generic interface objects and returns whether -// they point to the same object -func samePointers(first, second interface{}) bool { +// samePointers checks if two generic interface objects are pointers of the same +// type pointing to the same object. It returns two values: same indicating if +// they are the same type and point to the same object, and ok indicating that +// both inputs are pointers. +func samePointers(first, second interface{}) (same bool, ok bool) { firstPtr, secondPtr := reflect.ValueOf(first), reflect.ValueOf(second) if firstPtr.Kind() != reflect.Ptr || secondPtr.Kind() != reflect.Ptr { - return false + return false, false //not both are pointers } firstType, secondType := reflect.TypeOf(first), reflect.TypeOf(second) if firstType != secondType { - return false + return false, true // both are pointers, but of different types } // compare pointer addresses - return first == second + return first == second, true } // formatUnequalValues takes two values of arbitrary types and returns string diff --git a/assert/assertions_test.go b/assert/assertions_test.go index ca125998c..9f859fa8d 100644 --- a/assert/assertions_test.go +++ b/assert/assertions_test.go @@ -655,39 +655,59 @@ func Test_samePointers(t *testing.T) { second interface{} } tests := []struct { - name string - args args - assertion BoolAssertionFunc + name string + args args + same BoolAssertionFunc + ok BoolAssertionFunc }{ { - name: "1 != 2", - args: args{first: 1, second: 2}, - assertion: False, + name: "1 != 2", + args: args{first: 1, second: 2}, + same: False, + ok: False, + }, + { + name: "1 != 1 (not same ptr)", + args: args{first: 1, second: 1}, + same: False, + ok: False, + }, + { + name: "ptr(1) == ptr(1)", + args: args{first: p, second: p}, + same: True, + ok: True, }, { - name: "1 != 1 (not same ptr)", - args: args{first: 1, second: 1}, - assertion: False, + name: "int(1) != float32(1)", + args: args{first: int(1), second: float32(1)}, + same: False, + ok: False, }, { - name: "ptr(1) == ptr(1)", - args: args{first: p, second: p}, - assertion: True, + name: "array != slice", + args: args{first: [2]int{1, 2}, second: []int{1, 2}}, + same: False, + ok: False, }, { - name: "int(1) != float32(1)", - args: args{first: int(1), second: float32(1)}, - assertion: False, + name: "non-pointer vs pointer (1 != ptr(2))", + args: args{first: 1, second: p}, + same: False, + ok: False, }, { - name: "array != slice", - args: args{first: [2]int{1, 2}, second: []int{1, 2}}, - assertion: False, + name: "pointer vs non-pointer (ptr(2) != 1)", + args: args{first: p, second: 1}, + same: False, + ok: False, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tt.assertion(t, samePointers(tt.args.first, tt.args.second)) + same, ok := samePointers(tt.args.first, tt.args.second) + tt.same(t, same) + tt.ok(t, ok) }) } }