diff --git a/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java b/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java index 420d4eacbc..8ddc19a33d 100644 --- a/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java +++ b/src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java @@ -26,54 +26,91 @@ import org.openrewrite.marker.Markers; import static java.util.Collections.singletonList; +import static java.util.Objects.requireNonNull; +/** + * A visitor that identifies and addresses potential issues related to + * the use of {@code equals} methods in Java, particularly to avoid + * null pointer exceptions when comparing strings. + *

+ * This visitor looks for method invocations of {@code equals}, + * {@code equalsIgnoreCase}, {@code compareTo}, and {@code contentEquals}, + * and performs optimizations to ensure null checks are correctly applied. + *

+ * For more details, refer to the PMD best practices: + * Literals First in Comparisons + * + * @param

The type of the parent context used for visiting the AST. + */ @Value @EqualsAndHashCode(callSuper = false) public class EqualsAvoidsNullVisitor

extends JavaVisitor

{ - private static final MethodMatcher STRING_EQUALS = new MethodMatcher("String equals(java.lang.Object)"); - private static final MethodMatcher STRING_EQUALS_IGNORE_CASE = new MethodMatcher("String equalsIgnoreCase(java.lang.String)"); + + private static final MethodMatcher EQUALS = new MethodMatcher("java.lang.String " + "equals(java.lang.Object)"); + private static final MethodMatcher EQUALS_IGNORE_CASE = new MethodMatcher("java.lang.String " + "equalsIgnoreCase(java.lang.String)"); + private static final MethodMatcher COMPARE_TO = new MethodMatcher("java.lang.String " + "compareTo(java.lang.String)"); + private static final MethodMatcher COMPARE_TO_IGNORE_CASE = new MethodMatcher("java.lang.String " + "compareToIgnoreCase(java.lang.String)"); + private static final MethodMatcher CONTENT_EQUALS = new MethodMatcher("java.lang.String " + "contentEquals(java.lang.CharSequence)"); EqualsAvoidsNullStyle style; @Override public J visitMethodInvocation(J.MethodInvocation method, P p) { - J j = super.visitMethodInvocation(method, p); - if (!(j instanceof J.MethodInvocation)) { - return j; + J.MethodInvocation m = (J.MethodInvocation) super.visitMethodInvocation(method, p); + if (m.getSelect() != null && + !(m.getSelect() instanceof J.Literal) && + m.getArguments().get(0) instanceof J.Literal && + isStringComparisonMethod(m)) { + return literalsFirstInComparisonsBinaryCheck(m, getCursor().getParentTreeCursor().getValue()); } - J.MethodInvocation m = (J.MethodInvocation) j; - if (m.getSelect() == null) { - return m; + return m; + } + + private boolean isStringComparisonMethod(J.MethodInvocation methodInvocation) { + return EQUALS.matches(methodInvocation) || + !style.getIgnoreEqualsIgnoreCase() && + EQUALS_IGNORE_CASE.matches(methodInvocation) || + COMPARE_TO.matches(methodInvocation) || + COMPARE_TO_IGNORE_CASE.matches(methodInvocation) || + CONTENT_EQUALS.matches(methodInvocation); + } + + private Expression literalsFirstInComparisonsBinaryCheck(J.MethodInvocation m, P parent) { + if (parent instanceof J.Binary) { + handleBinaryExpression(m, (J.Binary) parent); } + return getExpression(m, m.getArguments().get(0)); + } - if ((STRING_EQUALS.matches(m) || (!Boolean.TRUE.equals(style.getIgnoreEqualsIgnoreCase()) && STRING_EQUALS_IGNORE_CASE.matches(m))) && - m.getArguments().get(0) instanceof J.Literal && - !(m.getSelect() instanceof J.Literal)) { - Tree parent = getCursor().getParentTreeCursor().getValue(); - if (parent instanceof J.Binary) { - J.Binary binary = (J.Binary) parent; - if (binary.getOperator() == J.Binary.Type.And && binary.getLeft() instanceof J.Binary) { - J.Binary potentialNullCheck = (J.Binary) binary.getLeft(); - if ((isNullLiteral(potentialNullCheck.getLeft()) && matchesSelect(potentialNullCheck.getRight(), m.getSelect())) || - (isNullLiteral(potentialNullCheck.getRight()) && matchesSelect(potentialNullCheck.getLeft(), m.getSelect()))) { - doAfterVisit(new RemoveUnnecessaryNullCheck<>(binary)); - } - } - } + private static Expression getExpression(J.MethodInvocation m, Expression firstArgument) { + return firstArgument.getType() == JavaType.Primitive.Null ? + literalsFirstInComparisonsNull(m, firstArgument) : + literalsFirstInComparisons(m, firstArgument); + } + + private static J.Binary literalsFirstInComparisonsNull(J.MethodInvocation m, Expression firstArgument) { + return new J.Binary(Tree.randomId(), + m.getPrefix(), + Markers.EMPTY, + requireNonNull(m.getSelect()), + JLeftPadded.build(J.Binary.Type.Equal).withBefore(Space.SINGLE_SPACE), + firstArgument.withPrefix(Space.SINGLE_SPACE), + JavaType.Primitive.Boolean); + } - if (m.getArguments().get(0).getType() == JavaType.Primitive.Null) { - return new J.Binary(Tree.randomId(), m.getPrefix(), Markers.EMPTY, - m.getSelect(), - JLeftPadded.build(J.Binary.Type.Equal).withBefore(Space.SINGLE_SPACE), - m.getArguments().get(0).withPrefix(Space.SINGLE_SPACE), - JavaType.Primitive.Boolean); - } else { - m = m.withSelect(((J.Literal) m.getArguments().get(0)).withPrefix(m.getSelect().getPrefix())) - .withArguments(singletonList(m.getSelect().withPrefix(Space.EMPTY))); + private static J.MethodInvocation literalsFirstInComparisons(J.MethodInvocation m, Expression firstArgument) { + return m.withSelect(firstArgument.withPrefix(requireNonNull(m.getSelect()).getPrefix())) + .withArguments(singletonList(m.getSelect().withPrefix(Space.EMPTY))); + } + + private void handleBinaryExpression(J.MethodInvocation m, J.Binary binary) { + if (binary.getOperator() == J.Binary.Type.And && binary.getLeft() instanceof J.Binary) { + J.Binary potentialNullCheck = (J.Binary) binary.getLeft(); + if (isNullLiteral(potentialNullCheck.getLeft()) && matchesSelect(potentialNullCheck.getRight(), requireNonNull(m.getSelect())) || + isNullLiteral(potentialNullCheck.getRight()) && matchesSelect(potentialNullCheck.getLeft(), requireNonNull(m.getSelect()))) { + doAfterVisit(new RemoveUnnecessaryNullCheck<>(binary)); } } - - return m; } private boolean isNullLiteral(Expression expression) { @@ -81,13 +118,20 @@ private boolean isNullLiteral(Expression expression) { } private boolean matchesSelect(Expression expression, Expression select) { - return expression.printTrimmed(getCursor()).replaceAll("\\s", "").equals(select.printTrimmed(getCursor()).replaceAll("\\s", "")); + return expression.printTrimmed(getCursor()).replaceAll("\\s", "") + .equals(select.printTrimmed(getCursor()).replaceAll("\\s", "")); } private static class RemoveUnnecessaryNullCheck

extends JavaVisitor

{ + private final J.Binary scope; + boolean done; + public RemoveUnnecessaryNullCheck(J.Binary scope) { + this.scope = scope; + } + @Override public @Nullable J visit(@Nullable Tree tree, P p) { if (done) { @@ -96,17 +140,12 @@ private static class RemoveUnnecessaryNullCheck

extends JavaVisitor

{ return super.visit(tree, p); } - public RemoveUnnecessaryNullCheck(J.Binary scope) { - this.scope = scope; - } - @Override public J visitBinary(J.Binary binary, P p) { if (scope.isScope(binary)) { done = true; return binary.getRight().withPrefix(Space.EMPTY); } - return super.visitBinary(binary, p); } } diff --git a/src/main/java/org/openrewrite/staticanalysis/UseStringReplace.java b/src/main/java/org/openrewrite/staticanalysis/UseStringReplace.java index a38a487cb2..ee21d9ace5 100644 --- a/src/main/java/org/openrewrite/staticanalysis/UseStringReplace.java +++ b/src/main/java/org/openrewrite/staticanalysis/UseStringReplace.java @@ -28,7 +28,6 @@ import java.time.Duration; import java.util.Collections; -import java.util.Objects; import java.util.Set; import java.util.regex.Pattern; @@ -90,7 +89,7 @@ public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) return invocation; // Might contain special characters; unsafe to replace } String secondValue = (String) ((J.Literal) secondArgument).getValue(); - if (Objects.nonNull(secondValue) && (secondValue.contains("$") || secondValue.contains("\\"))) { + if (secondValue != null && (secondValue.contains("$") || secondValue.contains("\\"))) { return invocation; // Does contain special characters; unsafe to replace } @@ -100,7 +99,7 @@ public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) // Checks if the String literal may not be a regular expression, // if so, then change the method invocation name String firstValue = (String) ((J.Literal) firstArgument).getValue(); - if (Objects.nonNull(firstValue) && !mayBeRegExp(firstValue)) { + if (firstValue != null && !mayBeRegExp(firstValue)) { String unEscapedLiteral = unEscapeCharacters(firstValue); invocation = invocation .withName(invocation.getName().withSimpleName("replace")) diff --git a/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java b/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java index 68bb55236a..1dde174578 100644 --- a/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java @@ -42,6 +42,9 @@ public class A { String s = null; if(s.equals("test")) {} if(s.equalsIgnoreCase("test")) {} + System.out.println(s.compareTo("test")); + System.out.println(s.compareToIgnoreCase("test")); + System.out.println(s.contentEquals("test")); } } """, @@ -51,6 +54,9 @@ public class A { String s = null; if("test".equals(s)) {} if("test".equalsIgnoreCase(s)) {} + System.out.println("test".compareTo(s)); + System.out.println("test".compareToIgnoreCase(s)); + System.out.println("test".contentEquals(s)); } } """