Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 41 additions & 53 deletions src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNull.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,25 +80,45 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
new JavaVisitor<ExecutionContext>() {
@Override
public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
J.MethodInvocation m = (J.MethodInvocation) super.visitMethodInvocation(method, ctx);
J.MethodInvocation mi = (J.MethodInvocation) super.visitMethodInvocation(method, ctx);

if (!isStringComparisonMethod(m)) {
return m;
if (mi.getSelect() instanceof J.Literal || !isStringComparisonMethod(mi) || !hasCompatibleArgument(mi)) {
return mi;
}

// Always check for redundant null checks, even if we won't swap arguments
maybeHandleParentBinary(m, getCursor().getParentTreeCursor().getValue());
Expression firstArgument = mi.getArguments().get(0);
return firstArgument.getType() == JavaType.Primitive.Null ?
literalsFirstInComparisonsNull(mi, firstArgument) :
literalsFirstInComparisons(mi, firstArgument);
}

if (!hasCompatibleArgument(m) || m.getSelect() instanceof J.Literal) {
return m;
@Override
public J visitBinary(J.Binary binary, ExecutionContext ctx) {
// First swap order of method invocation select and argument
J.Binary b = (J.Binary) super.visitBinary(binary, ctx);

// Independent of changes above, clear out unnecessary null comparisons
if (b.getLeft() instanceof J.Binary &&
b.getOperator() == J.Binary.Type.And &&
isStringComparisonMethod(b.getRight())) {
Expression nullCheckedLeft = nullCheckedArgument((J.Binary) b.getLeft());
if (nullCheckedLeft != null && areEqual(nullCheckedLeft, ((J.MethodInvocation) b.getRight()).getArguments().get(0))) {
return b.getRight().withPrefix(b.getPrefix());
}
}
return b;
}

Expression firstArgument = m.getArguments().get(0);

return firstArgument.getType() == JavaType.Primitive.Null ?
literalsFirstInComparisonsNull(m, firstArgument) :
literalsFirstInComparisons(m, firstArgument);

private @Nullable Expression nullCheckedArgument(J.Binary binary) {
if (binary.getOperator() == J.Binary.Type.NotEqual) {
if (isLiteralValue(binary.getLeft(), null)) {
return binary.getRight();
}
if (isLiteralValue(binary.getRight(), null)) {
return binary.getLeft();
}
}
return null;
}

private boolean hasCompatibleArgument(J.MethodInvocation m) {
Expand All @@ -119,47 +139,15 @@ private boolean hasCompatibleArgument(J.MethodInvocation m) {
return false;
}

private boolean isStringComparisonMethod(J.MethodInvocation methodInvocation) {
return EQUALS_STRING.matches(methodInvocation) ||
(EQUALS_OBJECT.matches(methodInvocation) && TypeUtils.isString(methodInvocation.getArguments().get(0).getType())) ||
EQUALS_IGNORE_CASE.matches(methodInvocation) ||
CONTENT_EQUALS.matches(methodInvocation);
}

private void maybeHandleParentBinary(J.MethodInvocation m, final Tree parent) {
if (parent instanceof J.Binary) {
if (((J.Binary) parent).getOperator() == J.Binary.Type.And &&
((J.Binary) parent).getLeft() instanceof J.Binary) {
J.Binary potentialNullCheck = (J.Binary) ((J.Binary) parent).getLeft();
boolean nullCheckMatchesSelect =
(isLiteralValue(potentialNullCheck.getLeft(), null) && areEqual(potentialNullCheck.getRight(), m.getSelect())) ||
(isLiteralValue(potentialNullCheck.getRight(), null) && areEqual(potentialNullCheck.getLeft(), m.getSelect()));
boolean nullCheckMatchesArgument =
(isLiteralValue(potentialNullCheck.getLeft(), null) && areEqual(potentialNullCheck.getRight(), m.getArguments().get(0))) ||
(isLiteralValue(potentialNullCheck.getRight(), null) && areEqual(potentialNullCheck.getLeft(), m.getArguments().get(0)));
if (nullCheckMatchesSelect || nullCheckMatchesArgument) {
doAfterVisit(new JavaVisitor<ExecutionContext>() {

private final J.Binary scope = (J.Binary) parent;
private boolean done;

@Override
public @Nullable J visit(@Nullable Tree tree, ExecutionContext ctx) {
return done ? (J) tree : super.visit(tree, ctx);
}

@Override
public J visitBinary(J.Binary binary, ExecutionContext ctx) {
if (scope.isScope(binary)) {
done = true;
return binary.getRight().withPrefix(binary.getPrefix());
}
return super.visitBinary(binary, ctx);
}
});
}
}
private boolean isStringComparisonMethod(J j) {
if (j instanceof J.MethodInvocation) {
J.MethodInvocation mi = (J.MethodInvocation) j;
return EQUALS_STRING.matches(mi) ||
(EQUALS_OBJECT.matches(mi) && TypeUtils.isString(mi.getArguments().get(0).getType())) ||
EQUALS_IGNORE_CASE.matches(mi) ||
CONTENT_EQUALS.matches(mi);
}
return false;
}

private J.Binary literalsFirstInComparisonsNull(J.MethodInvocation m, Expression firstArgument) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,23 +87,50 @@ void removeUnnecessaryNullCheck() {
java(
"""
public class A {
{
String s = null, t = null;
if(s != null && s.equals("test")) {}
if(null != s && s.equals("test")) {}
if(t != null && "test".equals(t)) {}
if(null != t && "test".equals(t)) {}
void check(String s, String t) {
if (s != null && s.equals("test")) {}
if (null != s && s.equals("test")) {}
if (t != null && "test".equals(t)) {}
if (null != t && "test".equals(t)) {}
}
}
""",
"""
public class A {
{
String s = null, t = null;
if("test".equals(s)) {}
if("test".equals(s)) {}
if("test".equals(t)) {}
if("test".equals(t)) {}
void check(String s, String t) {
if ("test".equals(s)) {}
if ("test".equals(s)) {}
if ("test".equals(t)) {}
if ("test".equals(t)) {}
}
}
"""
)
);
}

@Test
void retainNecessaryNullCheck() {
rewriteRun(
//language=java
java(
"""
class A {
void check(String expected, String actual){
if (expected != null && expected.equals(actual)) {}
if (actual != null && actual.equals(expected)) {}
if (expected != null && actual.equals(expected)) {}
if (actual != null && expected.equals(actual)) {}
}
}
""",
"""
class A {
void check(String expected, String actual){
if (expected != null && expected.equals(actual)) {}
if (actual != null && actual.equals(expected)) {}
if (actual.equals(expected)) {}
if (expected.equals(actual)) {}
}
}
"""
Expand Down