diff --git a/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNull.java b/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNull.java index 71041e2eb..9ca4cd267 100644 --- a/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNull.java +++ b/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNull.java @@ -38,8 +38,10 @@ public class EqualsAvoidsNull extends Recipe { private static final String JAVA_LANG_STRING = "java.lang.String"; + private static final String JAVA_LANG_OBJECT = "java.lang.Object"; - private static final MethodMatcher EQUALS = new MethodMatcher(JAVA_LANG_STRING + " equals(java.lang.Object)"); + private static final MethodMatcher EQUALS_STRING = new MethodMatcher(JAVA_LANG_STRING + " equals(" + JAVA_LANG_OBJECT + ")"); + private static final MethodMatcher EQUALS_OBJECT = new MethodMatcher(JAVA_LANG_OBJECT + " equals(" + JAVA_LANG_OBJECT + ")"); private static final MethodMatcher EQUALS_IGNORE_CASE = new MethodMatcher(JAVA_LANG_STRING + " equalsIgnoreCase(" + JAVA_LANG_STRING + ")"); private static final MethodMatcher CONTENT_EQUALS = new MethodMatcher(JAVA_LANG_STRING + " contentEquals(java.lang.CharSequence)"); @@ -66,7 +68,11 @@ public Duration getEstimatedEffortPerOccurrence() { @Override public TreeVisitor getVisitor() { return Preconditions.check( - Preconditions.or(new UsesMethod<>(EQUALS), new UsesMethod<>(EQUALS_IGNORE_CASE), new UsesMethod<>(CONTENT_EQUALS)), + Preconditions.or( + new UsesMethod<>(EQUALS_STRING), + new UsesMethod<>(EQUALS_OBJECT), + new UsesMethod<>(EQUALS_IGNORE_CASE), + new UsesMethod<>(CONTENT_EQUALS)), new JavaVisitor() { @Override public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { @@ -104,7 +110,8 @@ private boolean hasCompatibleArgument(J.MethodInvocation m) { } private boolean isStringComparisonMethod(J.MethodInvocation methodInvocation) { - return EQUALS.matches(methodInvocation) || + return EQUALS_STRING.matches(methodInvocation) || + EQUALS_OBJECT.matches(methodInvocation) || EQUALS_IGNORE_CASE.matches(methodInvocation) || CONTENT_EQUALS.matches(methodInvocation); } @@ -116,8 +123,8 @@ private void maybeHandleParentBinary(J.MethodInvocation m, final Tree parent) { J.Binary potentialNullCheck = (J.Binary) ((J.Binary) parent).getLeft(); if (isNullLiteral(potentialNullCheck.getLeft()) && matchesSelect(potentialNullCheck.getRight(), requireNonNull(m.getSelect())) || - isNullLiteral(potentialNullCheck.getRight()) && - matchesSelect(potentialNullCheck.getLeft(), requireNonNull(m.getSelect()))) { + isNullLiteral(potentialNullCheck.getRight()) && + matchesSelect(potentialNullCheck.getLeft(), requireNonNull(m.getSelect()))) { doAfterVisit(new JavaVisitor() { private final J.Binary scope = (J.Binary) parent; diff --git a/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java b/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java index 203976ee9..10ea0821b 100644 --- a/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java @@ -113,6 +113,31 @@ void foo(String s) { ); } + @Test + void ObjectEquals() { + rewriteRun( + //language=java + java( + """ + class A { + void foo(Object s) { + if (s.equals("null")) { + } + } + } + """, + """ + class A { + void foo(Object s) { + if ("null".equals(s)) { + } + } + } + """ + ) + ); + } + @Nested class ReplaceConstantMethodArg {