diff --git a/src/main/java/org/openrewrite/staticanalysis/SimplifyBooleanReturn.java b/src/main/java/org/openrewrite/staticanalysis/SimplifyBooleanReturn.java index 118455179d..45bb714b84 100644 --- a/src/main/java/org/openrewrite/staticanalysis/SimplifyBooleanReturn.java +++ b/src/main/java/org/openrewrite/staticanalysis/SimplifyBooleanReturn.java @@ -70,18 +70,24 @@ public J visitIf(J.If iff, ExecutionContext ctx) { Cursor parent = getCursor().getParentTreeCursor(); if (parent.getValue() instanceof J.Block && - parent.getParentOrThrow().getValue() instanceof J.MethodDeclaration && - thenHasOnlyReturnStatement(iff) && - elseWithOnlyReturn(i)) { + parent.getParentOrThrow().getValue() instanceof J.MethodDeclaration && + thenHasOnlyReturnStatement(iff) && + elseWithOnlyReturn(i)) { List followingStatements = followingStatements(); Optional singleFollowingStatement = Optional.ofNullable(followingStatements.isEmpty() ? null : followingStatements.get(0)) .flatMap(stat -> Optional.ofNullable(stat instanceof J.Return ? (J.Return) stat : null)) + .filter(r -> r.getComments().isEmpty()) .map(J.Return::getExpression); if (followingStatements.isEmpty() || singleFollowingStatement.map(r -> isLiteralFalse(r) || isLiteralTrue(r)).orElse(false)) { J.Return return_ = getReturnIfOnlyStatementInThen(iff).orElse(null); assert return_ != null; + // Do not remove comments that are attached to the return statement + if (!return_.getComments().isEmpty() || hasElseWithComment(i.getElsePart())) { + return i; + } + Expression ifCondition = i.getIfCondition().getTree(); if (isLiteralTrue(return_.getExpression())) { @@ -89,7 +95,7 @@ public J visitIf(J.If iff, ExecutionContext ctx) { doAfterVisit(new DeleteStatement<>(followingStatements().get(0))); return maybeAutoFormat(return_, return_.withExpression(ifCondition), ctx, parent); } else if (!singleFollowingStatement.isPresent() && - getReturnExprIfOnlyStatementInElseThen(i).map(this::isLiteralFalse).orElse(false)) { + getReturnExprIfOnlyStatementInElseThen(i).map(this::isLiteralFalse).orElse(false)) { if (i.getElsePart() != null) { doAfterVisit(new DeleteStatement<>(i.getElsePart().getBody())); } @@ -185,6 +191,23 @@ private Optional getReturnExprIfOnlyStatementInElseThen(J.If iff2) { return Optional.empty(); } + + private boolean hasElseWithComment(J.If.Else else_) { + if (else_ == null || else_.getBody() == null) { + return false; + } + if (!else_.getComments().isEmpty()) { + return true; + } + if (!else_.getBody().getComments().isEmpty()) { + return true; + } + if (else_.getBody() instanceof J.Block + && !((J.Block) else_.getBody()).getStatements().get(0).getComments().isEmpty()) { + return true; + } + return false; + } }; } diff --git a/src/test/java/org/openrewrite/staticanalysis/SimplifyBooleanReturnTest.java b/src/test/java/org/openrewrite/staticanalysis/SimplifyBooleanReturnTest.java index 02072ecdc7..81418513ca 100644 --- a/src/test/java/org/openrewrite/staticanalysis/SimplifyBooleanReturnTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/SimplifyBooleanReturnTest.java @@ -15,6 +15,7 @@ */ package org.openrewrite.staticanalysis; +import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.openrewrite.DocumentExample; import org.openrewrite.test.RecipeSpec; @@ -34,6 +35,7 @@ public void defaults(RecipeSpec spec) { @Test void simplifyBooleanReturn() { rewriteRun( + //language=java java( """ public class A { @@ -74,6 +76,7 @@ static boolean isOddMillis() { @Test void dontSimplifyToReturnUnlessLastStatement() { rewriteRun( + //language=java java( """ public class A { @@ -105,6 +108,7 @@ public boolean absurdEquals(Object o) { @Test void nestedIfsWithNoBlock() { rewriteRun( + //language=java java( """ public class A { @@ -123,6 +127,7 @@ public boolean absurdEquals(Object o) { @Test void dontAlterWhenElseIfPresent() { rewriteRun( + //language=java java( """ public class A { @@ -146,6 +151,7 @@ else if (n == 2) { @Test void dontAlterWhenElseContainsSomethingOtherThanReturn() { rewriteRun( + //language=java java( """ public class A { @@ -167,6 +173,7 @@ public boolean foo(int n) { @Test void onlySimplifyToReturnWhenLastStatement() { rewriteRun( + //language=java java( """ import java.util.*; @@ -188,6 +195,7 @@ public static boolean deepEquals(List l, List r) { @Test void wrapNotReturnsOfTernaryIfConditionsInParentheses() { rewriteRun( + //language=java java( """ public class A { @@ -211,4 +219,92 @@ public boolean equals(Object o) { ) ); } + + @Nested + class RetainComments { + @Test + void onIfReturn() { + rewriteRun( + //language=java + java( + """ + class A { + boolean foo(int n) { + if (n == 1) { + // A comment that provides important context + return true; + } + else { + return false; + } + } + } + """ + ) + ); + } + + @Test + void onElseBlockReturn() { + rewriteRun( + //language=java + java( + """ + class A { + boolean foo(int n) { + if (n == 1) { + return true; + } + else { + // A comment that provides important context + return false; + } + } + } + """ + ) + ); + } + + @Test + void onElseReturn() { + rewriteRun( + //language=java + java( + """ + class A { + boolean foo(int n) { + if (n == 1) { + return true; + } + else + // A comment that provides important context + return false; + } + } + """ + ) + ); + } + + @Test + void onImpliedElse() { + rewriteRun( + //language=java + java( + """ + class A { + boolean foo(int n) { + if (n == 1) { + return true; + } + // A comment that provides important context + return false; + } + } + """ + ) + ); + } + } }