diff --git a/src/main/java/org/openrewrite/staticanalysis/RemoveUnreachableCodeVisitor.java b/src/main/java/org/openrewrite/staticanalysis/RemoveUnreachableCodeVisitor.java new file mode 100644 index 0000000000..5d1b6868e7 --- /dev/null +++ b/src/main/java/org/openrewrite/staticanalysis/RemoveUnreachableCodeVisitor.java @@ -0,0 +1,74 @@ +/* + * Copyright 2023 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.staticanalysis; + +import org.openrewrite.ExecutionContext; +import org.openrewrite.internal.ListUtils; +import org.openrewrite.java.JavaVisitor; +import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.Statement; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +class RemoveUnreachableCodeVisitor extends JavaVisitor { + + @Override + public J visitBlock(J.Block block, ExecutionContext executionContext) { + block = (J.Block) super.visitBlock(block, executionContext); + + List statements = block.getStatements(); + Optional maybeFirstJumpIndex = findFirstJump(statements); + if (!maybeFirstJumpIndex.isPresent()) { + return block; + } + int firstJumpIndex = maybeFirstJumpIndex.get(); + + if (firstJumpIndex == statements.size() - 1) { + // Jump is at the end of the block, so nothing to do + return block; + } + + List newStatements = + ListUtils.flatMap( + block.getStatements(), + (index, statement) -> { + if (index <= firstJumpIndex) { + return statement; + } + return Collections.emptyList(); + } + ); + + return block.withStatements(newStatements); + } + + private Optional findFirstJump(List statements) { + for (int i = 0; i < statements.size(); i++) { + Statement statement = statements.get(i); + if ( + statement instanceof J.Return || + statement instanceof J.Throw || + statement instanceof J.Break || + statement instanceof J.Continue + ) { + return Optional.of(i); + } + } + return Optional.empty(); + } +} diff --git a/src/main/java/org/openrewrite/staticanalysis/SimplifyConstantIfBranchExecution.java b/src/main/java/org/openrewrite/staticanalysis/SimplifyConstantIfBranchExecution.java index b3273e3e6f..8e61a25f78 100644 --- a/src/main/java/org/openrewrite/staticanalysis/SimplifyConstantIfBranchExecution.java +++ b/src/main/java/org/openrewrite/staticanalysis/SimplifyConstantIfBranchExecution.java @@ -20,7 +20,6 @@ import org.openrewrite.SourceFile; import org.openrewrite.TreeVisitor; import org.openrewrite.internal.lang.Nullable; -import org.openrewrite.java.JavaIsoVisitor; import org.openrewrite.java.JavaVisitor; import org.openrewrite.java.cleanup.SimplifyBooleanExpressionVisitor; import org.openrewrite.java.cleanup.UnnecessaryParenthesesVisitor; @@ -32,7 +31,6 @@ import org.openrewrite.java.tree.Statement; import java.util.Optional; -import java.util.concurrent.atomic.AtomicBoolean; public class SimplifyConstantIfBranchExecution extends Recipe { @@ -93,10 +91,6 @@ public J visitIf(J.If if_, ExecutionContext context) { J.ControlParentheses cp = cleanupBooleanExpression(if__.getIfCondition(), context); if__ = if__.withIfCondition(cp); - if (visitsKeyWord(if__)) { - return if__; - } - // The compile-time constant value of the if condition control parentheses. final Optional compileTimeConstantBoolean; if (isLiteralTrue(cp.getTree())) { @@ -115,6 +109,7 @@ public J visitIf(J.If if_, ExecutionContext context) { // True branch // Only keep the `then` branch, and remove the `else` branch. Statement s = if__.getThenPart().withPrefix(if__.getPrefix()); + doAfterVisit(new RemoveUnreachableCodeVisitor()); return maybeAutoFormat( if__, s, @@ -126,6 +121,7 @@ public J visitIf(J.If if_, ExecutionContext context) { if (if__.getElsePart() != null) { // The `else` part needs to be kept Statement s = if__.getElsePart().getBody().withPrefix(if__.getPrefix()); + doAfterVisit(new RemoveUnreachableCodeVisitor()); return maybeAutoFormat( if__, s, @@ -151,41 +147,6 @@ public J visitIf(J.If if_, ExecutionContext context) { } } - private boolean visitsKeyWord(J.If iff) { - if (isLiteralFalse(iff.getIfCondition().getTree())) { - return false; - } - - AtomicBoolean visitedCFKeyword = new AtomicBoolean(false); - // if there is a return, break, continue, throws in _then, then set visitedKeyword to true - new JavaIsoVisitor() { - @Override - public J.Return visitReturn(J.Return _return, AtomicBoolean atomicBoolean) { - atomicBoolean.set(true); - return _return; - } - - @Override - public J.Continue visitContinue(J.Continue continueStatement, AtomicBoolean atomicBoolean) { - atomicBoolean.set(true); - return continueStatement; - } - - @Override - public J.Break visitBreak(J.Break breakStatement, AtomicBoolean atomicBoolean) { - atomicBoolean.set(true); - return breakStatement; - } - - @Override - public J.Throw visitThrow(J.Throw thrown, AtomicBoolean atomicBoolean) { - atomicBoolean.set(true); - return thrown; - } - }.visit(iff.getThenPart(), visitedCFKeyword); - return visitedCFKeyword.get(); - } - private static boolean isLiteralTrue(@Nullable Expression expression) { return J.Literal.isLiteralValue(expression, Boolean.TRUE); } diff --git a/src/test/java/org/openrewrite/staticanalysis/SimplifyConstantIfBranchExecutionTest.java b/src/test/java/org/openrewrite/staticanalysis/SimplifyConstantIfBranchExecutionTest.java index 6e1ab060cf..9393e24dfa 100644 --- a/src/test/java/org/openrewrite/staticanalysis/SimplifyConstantIfBranchExecutionTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/SimplifyConstantIfBranchExecutionTest.java @@ -666,18 +666,54 @@ public void test() { } @Test - void doesNotRemoveWhenReturnInIfBlock() { + void removesWhenReturnInThenBlock() { rewriteRun( //language=java java( """ public class A { public void test() { + System.out.println("before"); if (true) { - System.out.println("hello"); + System.out.println("then"); return; } - System.out.println("goodbye"); + System.out.println("after"); + } + } + """, + """ + public class A { + public void test() { + System.out.println("before"); + System.out.println("then"); + return; + } + } + """ + ) + ); + } + + @Test + void removesWhenReturnInThenNoBlock() { + rewriteRun( + //language=java + java( + """ + public class A { + public void test() { + System.out.println("before"); + if (true) return; + System.out.println("after"); + } + } + """, + """ + public class A { + public void test() { + System.out.println("before"); + return; } } """ @@ -686,18 +722,151 @@ public void test() { } @Test - void doesNotRemoveWhenThrowsInIfBlock() { + void removesWhenReturnInThenBlockWithElse() { rewriteRun( //language=java java( """ public class A { public void test() { + System.out.println("before"); if (true) { - System.out.println("hello"); + System.out.println("then"); + return; + } else { + System.out.println("else"); + } + System.out.println("after"); + } + } + """, + """ + public class A { + public void test() { + System.out.println("before"); + System.out.println("then"); + return; + } + }""" + ) + ); + } + + @Test + void removesWhenReturnInElseBlock() { + rewriteRun( + //language=java + java( + """ + public class A { + public void test() { + System.out.println("before"); + if (false) { + System.out.println("then"); + } else { + System.out.println("else"); + return; + } + System.out.println("after"); + } + } + """, + """ + public class A { + public void test() { + System.out.println("before"); + System.out.println("else"); + return; + } + } + """ + ) + ); + } + + @Test + void removesWhenReturnInElseNoBlock() { + rewriteRun( + //language=java + java( + """ + public class A { + public void test() { + System.out.println("before"); + if (false) { + System.out.println("then"); + } else return; + System.out.println("after"); + } + } + """, + """ + public class A { + public void test() { + System.out.println("before"); + return; + } + } + """ + ) + ); + } + + @Test + void removesWhenThrowsInThenBlock() { + rewriteRun( + //language=java + java( + """ + public class A { + public void test() { + System.out.println("before"); + if (true) { + System.out.println("then"); throw new RuntimeException(); } - System.out.println("goodbye"); + System.out.println("after"); + } + } + """, + """ + public class A { + public void test() { + System.out.println("before"); + System.out.println("then"); + throw new RuntimeException(); + } + } + """ + ) + ); + } + + @Test + void removesWhenThrowsInThenBlockWithElse() { + rewriteRun( + //language=java + java( + """ + public class A { + public void test() { + System.out.println("before"); + if (true) { + System.out.println("then"); + throw new RuntimeException(); + } else { + System.out.println("else"); + } + System.out.println("after"); + } + } + """, + """ + public class A { + public void test() { + System.out.println("before"); + System.out.println("then"); + throw new RuntimeException(); } } """ @@ -706,21 +875,68 @@ public void test() { } @Test - void doesNotRemoveWhenBreakInIfBlockWithinWhile() { + void removesWhenThrowsInElseBlock() { rewriteRun( //language=java java( """ public class A { public void test() { - while (true){ + System.out.println("before"); + if (false) { + System.out.println("then"); + } else { + System.out.println("else"); + throw new RuntimeException(); + } + System.out.println("after"); + } + } + """, + """ + public class A { + public void test() { + System.out.println("before"); + System.out.println("else"); + throw new RuntimeException(); + } + } + """ + ) + ); + } + + @Test + void removesWhenBreakInThenBlockWithinWhile() { + rewriteRun( + //language=java + java( + """ + public class A { + public void test() { + System.out.println("before while"); + while (true) { + System.out.println("before if"); if (true) { - System.out.println("hello"); + System.out.println("then"); break; } - System.out.println("goodbye"); + System.out.println("after if"); } - System.out.println("goodbye"); + System.out.println("after while"); + } + } + """, + """ + public class A { + public void test() { + System.out.println("before while"); + while (true) { + System.out.println("before if"); + System.out.println("then"); + break; + } + System.out.println("after while"); } } """ @@ -729,21 +945,231 @@ public void test() { } @Test - void doesNotRemoveWhenContinueInIfBlockWithinWhile() { + void removesWhenBreakInThenBlockWithElseWithinWhile() { rewriteRun( //language=java java( """ public class A { public void test() { + System.out.println("before while"); while (true) { + System.out.println("before if"); if (true) { - System.out.println("hello"); + System.out.println("then"); + break; + } else { + System.out.println("else"); + } + System.out.println("after if"); + } + System.out.println("after while"); + } + } + """, + """ + public class A { + public void test() { + System.out.println("before while"); + while (true) { + System.out.println("before if"); + System.out.println("then"); + break; + } + System.out.println("after while"); + } + } + """ + ) + ); + } + + @Test + void removesWhenBreakInElseBlockWithinWhile() { + rewriteRun( + //language=java + java( + """ + public class A { + public void test() { + System.out.println("before while"); + while (true) { + System.out.println("before if"); + if (false) { + System.out.println("then"); + } else { + System.out.println("else"); + break; + } + System.out.println("after if"); + } + System.out.println("after while"); + } + } + """, + """ + public class A { + public void test() { + System.out.println("before while"); + while (true) { + System.out.println("before if"); + System.out.println("else"); + break; + } + System.out.println("after while"); + } + } + """ + ) + ); + } + + @Test + void removesWhenContinueInThenBlockWithinWhile() { + rewriteRun( + //language=java + java( + """ + public class A { + public void test() { + System.out.println("before while"); + while (true) { + System.out.println("before if"); + if (true) { + System.out.println("then"); continue; } - System.out.println("goodbye"); + System.out.println("after if"); } - System.out.println("goodbye"); + System.out.println("after while"); + } + } + """, + """ + public class A { + public void test() { + System.out.println("before while"); + while (true) { + System.out.println("before if"); + System.out.println("then"); + continue; + } + System.out.println("after while"); + } + } + """ + ) + ); + } + + @Test + void removesWhenContinueInThenBlockWithElseWithinWhile() { + rewriteRun( + //language=java + java( + """ + public class A { + public void test() { + System.out.println("before while"); + while (true) { + System.out.println("before if"); + if (true) { + System.out.println("then"); + continue; + } else { + System.out.println("else"); + } + System.out.println("after if"); + } + System.out.println("after while"); + } + } + """, + """ + public class A { + public void test() { + System.out.println("before while"); + while (true) { + System.out.println("before if"); + System.out.println("then"); + continue; + } + System.out.println("after while"); + } + } + """ + ) + ); + } + + @Test + void removesWhenContinueInElseBlockWithinWhile() { + rewriteRun( + //language=java + java( + """ + public class A { + public void test() { + System.out.println("before while"); + while (true) { + System.out.println("before if"); + if (false) { + System.out.println("then"); + } else { + System.out.println("else"); + continue; + } + System.out.println("after if"); + } + System.out.println("after while"); + } + } + """, + """ + public class A { + public void test() { + System.out.println("before while"); + while (true) { + System.out.println("before if"); + System.out.println("else"); + continue; + } + System.out.println("after while"); + } + } + """ + ) + ); + } + + @Test + void removesNestedWithReturn() { + rewriteRun( + //language=java + java( + """ + public class A { + public void test() { + System.out.println("before outer if"); + if (true) { + System.out.println("outer then"); + if (true) { + System.out.println("inner then"); + return; + } + System.out.println("after inner if"); + } + System.out.println("after outer if"); + } + } + """, + """ + public class A { + public void test() { + System.out.println("before outer if"); + System.out.println("outer then"); + System.out.println("inner then"); + return; } } """ @@ -751,6 +1177,54 @@ public void test() { ); } + @Test + void simplifyNestedWithReturnAndThrowInTryCatch() { + rewriteRun( + //language=java + java( + """ + public class A { + public void test() { + System.out.println("before outer if"); + if (true) { + System.out.println("outer then"); + if (true) { + try { + if(true) { + throw new RuntimeException("Explosion"); + } + return; + } catch (Exception ex) { + System.out.println("catch"); + } + } + System.out.println("after inner if"); + } + System.out.println("after outer if"); + } + } + """, + """ + public class A { + public void test() { + System.out.println("before outer if"); + System.out.println("outer then"); + try { + throw new RuntimeException("Explosion"); + } catch (Exception ex) { + System.out.println("catch"); + } + System.out.println("after inner if"); + System.out.println("after outer if"); + } + } + """ + ) + ); + } + + + @Test void binaryOrIsAlwaysFalse() { rewriteRun(