diff --git a/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNull.java b/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNull.java index 3ef6ea774b..22fd10ef32 100644 --- a/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNull.java +++ b/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNull.java @@ -80,25 +80,45 @@ public TreeVisitor getVisitor() { new JavaVisitor() { @Override public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { - J.MethodInvocation m = (J.MethodInvocation) super.visitMethodInvocation(method, ctx); + J.MethodInvocation mi = (J.MethodInvocation) super.visitMethodInvocation(method, ctx); - if (!isStringComparisonMethod(m)) { - return m; + if (mi.getSelect() instanceof J.Literal || !isStringComparisonMethod(mi) || !hasCompatibleArgument(mi)) { + return mi; } - // Always check for redundant null checks, even if we won't swap arguments - maybeHandleParentBinary(m, getCursor().getParentTreeCursor().getValue()); + Expression firstArgument = mi.getArguments().get(0); + return firstArgument.getType() == JavaType.Primitive.Null ? + literalsFirstInComparisonsNull(mi, firstArgument) : + literalsFirstInComparisons(mi, firstArgument); + } - if (!hasCompatibleArgument(m) || m.getSelect() instanceof J.Literal) { - return m; + @Override + public J visitBinary(J.Binary binary, ExecutionContext ctx) { + // First swap order of method invocation select and argument + J.Binary b = (J.Binary) super.visitBinary(binary, ctx); + + // Independent of changes above, clear out unnecessary null comparisons + if (b.getLeft() instanceof J.Binary && + b.getOperator() == J.Binary.Type.And && + isStringComparisonMethod(b.getRight())) { + Expression nullCheckedLeft = nullCheckedArgument((J.Binary) b.getLeft()); + if (nullCheckedLeft != null && areEqual(nullCheckedLeft, ((J.MethodInvocation) b.getRight()).getArguments().get(0))) { + return b.getRight().withPrefix(b.getPrefix()); + } } + return b; + } - Expression firstArgument = m.getArguments().get(0); - - return firstArgument.getType() == JavaType.Primitive.Null ? - literalsFirstInComparisonsNull(m, firstArgument) : - literalsFirstInComparisons(m, firstArgument); - + private @Nullable Expression nullCheckedArgument(J.Binary binary) { + if (binary.getOperator() == J.Binary.Type.NotEqual) { + if (isLiteralValue(binary.getLeft(), null)) { + return binary.getRight(); + } + if (isLiteralValue(binary.getRight(), null)) { + return binary.getLeft(); + } + } + return null; } private boolean hasCompatibleArgument(J.MethodInvocation m) { @@ -119,47 +139,15 @@ private boolean hasCompatibleArgument(J.MethodInvocation m) { return false; } - private boolean isStringComparisonMethod(J.MethodInvocation methodInvocation) { - return EQUALS_STRING.matches(methodInvocation) || - (EQUALS_OBJECT.matches(methodInvocation) && TypeUtils.isString(methodInvocation.getArguments().get(0).getType())) || - EQUALS_IGNORE_CASE.matches(methodInvocation) || - CONTENT_EQUALS.matches(methodInvocation); - } - - private void maybeHandleParentBinary(J.MethodInvocation m, final Tree parent) { - if (parent instanceof J.Binary) { - if (((J.Binary) parent).getOperator() == J.Binary.Type.And && - ((J.Binary) parent).getLeft() instanceof J.Binary) { - J.Binary potentialNullCheck = (J.Binary) ((J.Binary) parent).getLeft(); - boolean nullCheckMatchesSelect = - (isLiteralValue(potentialNullCheck.getLeft(), null) && areEqual(potentialNullCheck.getRight(), m.getSelect())) || - (isLiteralValue(potentialNullCheck.getRight(), null) && areEqual(potentialNullCheck.getLeft(), m.getSelect())); - boolean nullCheckMatchesArgument = - (isLiteralValue(potentialNullCheck.getLeft(), null) && areEqual(potentialNullCheck.getRight(), m.getArguments().get(0))) || - (isLiteralValue(potentialNullCheck.getRight(), null) && areEqual(potentialNullCheck.getLeft(), m.getArguments().get(0))); - if (nullCheckMatchesSelect || nullCheckMatchesArgument) { - doAfterVisit(new JavaVisitor() { - - private final J.Binary scope = (J.Binary) parent; - private boolean done; - - @Override - public @Nullable J visit(@Nullable Tree tree, ExecutionContext ctx) { - return done ? (J) tree : super.visit(tree, ctx); - } - - @Override - public J visitBinary(J.Binary binary, ExecutionContext ctx) { - if (scope.isScope(binary)) { - done = true; - return binary.getRight().withPrefix(binary.getPrefix()); - } - return super.visitBinary(binary, ctx); - } - }); - } - } + private boolean isStringComparisonMethod(J j) { + if (j instanceof J.MethodInvocation) { + J.MethodInvocation mi = (J.MethodInvocation) j; + return EQUALS_STRING.matches(mi) || + (EQUALS_OBJECT.matches(mi) && TypeUtils.isString(mi.getArguments().get(0).getType())) || + EQUALS_IGNORE_CASE.matches(mi) || + CONTENT_EQUALS.matches(mi); } + return false; } private J.Binary literalsFirstInComparisonsNull(J.MethodInvocation m, Expression firstArgument) { diff --git a/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java b/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java index a9200222b4..261ec98871 100644 --- a/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java @@ -87,23 +87,50 @@ void removeUnnecessaryNullCheck() { java( """ public class A { - { - String s = null, t = null; - if(s != null && s.equals("test")) {} - if(null != s && s.equals("test")) {} - if(t != null && "test".equals(t)) {} - if(null != t && "test".equals(t)) {} + void check(String s, String t) { + if (s != null && s.equals("test")) {} + if (null != s && s.equals("test")) {} + if (t != null && "test".equals(t)) {} + if (null != t && "test".equals(t)) {} } } """, """ public class A { - { - String s = null, t = null; - if("test".equals(s)) {} - if("test".equals(s)) {} - if("test".equals(t)) {} - if("test".equals(t)) {} + void check(String s, String t) { + if ("test".equals(s)) {} + if ("test".equals(s)) {} + if ("test".equals(t)) {} + if ("test".equals(t)) {} + } + } + """ + ) + ); + } + + @Test + void retainNecessaryNullCheck() { + rewriteRun( + //language=java + java( + """ + class A { + void check(String expected, String actual){ + if (expected != null && expected.equals(actual)) {} + if (actual != null && actual.equals(expected)) {} + if (expected != null && actual.equals(expected)) {} + if (actual != null && expected.equals(actual)) {} + } + } + """, + """ + class A { + void check(String expected, String actual){ + if (expected != null && expected.equals(actual)) {} + if (actual != null && actual.equals(expected)) {} + if (actual.equals(expected)) {} + if (expected.equals(actual)) {} } } """