diff --git a/src/main/java/org/openrewrite/staticanalysis/RemoveUnreachableCodeVisitor.java b/src/main/java/org/openrewrite/staticanalysis/RemoveUnreachableCodeVisitor.java
new file mode 100644
index 0000000000..5d1b6868e7
--- /dev/null
+++ b/src/main/java/org/openrewrite/staticanalysis/RemoveUnreachableCodeVisitor.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright 2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.openrewrite.staticanalysis;
+
+import org.openrewrite.ExecutionContext;
+import org.openrewrite.internal.ListUtils;
+import org.openrewrite.java.JavaVisitor;
+import org.openrewrite.java.tree.J;
+import org.openrewrite.java.tree.Statement;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+class RemoveUnreachableCodeVisitor extends JavaVisitor {
+
+ @Override
+ public J visitBlock(J.Block block, ExecutionContext executionContext) {
+ block = (J.Block) super.visitBlock(block, executionContext);
+
+ List statements = block.getStatements();
+ Optional maybeFirstJumpIndex = findFirstJump(statements);
+ if (!maybeFirstJumpIndex.isPresent()) {
+ return block;
+ }
+ int firstJumpIndex = maybeFirstJumpIndex.get();
+
+ if (firstJumpIndex == statements.size() - 1) {
+ // Jump is at the end of the block, so nothing to do
+ return block;
+ }
+
+ List newStatements =
+ ListUtils.flatMap(
+ block.getStatements(),
+ (index, statement) -> {
+ if (index <= firstJumpIndex) {
+ return statement;
+ }
+ return Collections.emptyList();
+ }
+ );
+
+ return block.withStatements(newStatements);
+ }
+
+ private Optional findFirstJump(List statements) {
+ for (int i = 0; i < statements.size(); i++) {
+ Statement statement = statements.get(i);
+ if (
+ statement instanceof J.Return ||
+ statement instanceof J.Throw ||
+ statement instanceof J.Break ||
+ statement instanceof J.Continue
+ ) {
+ return Optional.of(i);
+ }
+ }
+ return Optional.empty();
+ }
+}
diff --git a/src/main/java/org/openrewrite/staticanalysis/SimplifyConstantIfBranchExecution.java b/src/main/java/org/openrewrite/staticanalysis/SimplifyConstantIfBranchExecution.java
index b3273e3e6f..8e61a25f78 100644
--- a/src/main/java/org/openrewrite/staticanalysis/SimplifyConstantIfBranchExecution.java
+++ b/src/main/java/org/openrewrite/staticanalysis/SimplifyConstantIfBranchExecution.java
@@ -20,7 +20,6 @@
import org.openrewrite.SourceFile;
import org.openrewrite.TreeVisitor;
import org.openrewrite.internal.lang.Nullable;
-import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaVisitor;
import org.openrewrite.java.cleanup.SimplifyBooleanExpressionVisitor;
import org.openrewrite.java.cleanup.UnnecessaryParenthesesVisitor;
@@ -32,7 +31,6 @@
import org.openrewrite.java.tree.Statement;
import java.util.Optional;
-import java.util.concurrent.atomic.AtomicBoolean;
public class SimplifyConstantIfBranchExecution extends Recipe {
@@ -93,10 +91,6 @@ public J visitIf(J.If if_, ExecutionContext context) {
J.ControlParentheses cp = cleanupBooleanExpression(if__.getIfCondition(), context);
if__ = if__.withIfCondition(cp);
- if (visitsKeyWord(if__)) {
- return if__;
- }
-
// The compile-time constant value of the if condition control parentheses.
final Optional compileTimeConstantBoolean;
if (isLiteralTrue(cp.getTree())) {
@@ -115,6 +109,7 @@ public J visitIf(J.If if_, ExecutionContext context) {
// True branch
// Only keep the `then` branch, and remove the `else` branch.
Statement s = if__.getThenPart().withPrefix(if__.getPrefix());
+ doAfterVisit(new RemoveUnreachableCodeVisitor());
return maybeAutoFormat(
if__,
s,
@@ -126,6 +121,7 @@ public J visitIf(J.If if_, ExecutionContext context) {
if (if__.getElsePart() != null) {
// The `else` part needs to be kept
Statement s = if__.getElsePart().getBody().withPrefix(if__.getPrefix());
+ doAfterVisit(new RemoveUnreachableCodeVisitor());
return maybeAutoFormat(
if__,
s,
@@ -151,41 +147,6 @@ public J visitIf(J.If if_, ExecutionContext context) {
}
}
- private boolean visitsKeyWord(J.If iff) {
- if (isLiteralFalse(iff.getIfCondition().getTree())) {
- return false;
- }
-
- AtomicBoolean visitedCFKeyword = new AtomicBoolean(false);
- // if there is a return, break, continue, throws in _then, then set visitedKeyword to true
- new JavaIsoVisitor() {
- @Override
- public J.Return visitReturn(J.Return _return, AtomicBoolean atomicBoolean) {
- atomicBoolean.set(true);
- return _return;
- }
-
- @Override
- public J.Continue visitContinue(J.Continue continueStatement, AtomicBoolean atomicBoolean) {
- atomicBoolean.set(true);
- return continueStatement;
- }
-
- @Override
- public J.Break visitBreak(J.Break breakStatement, AtomicBoolean atomicBoolean) {
- atomicBoolean.set(true);
- return breakStatement;
- }
-
- @Override
- public J.Throw visitThrow(J.Throw thrown, AtomicBoolean atomicBoolean) {
- atomicBoolean.set(true);
- return thrown;
- }
- }.visit(iff.getThenPart(), visitedCFKeyword);
- return visitedCFKeyword.get();
- }
-
private static boolean isLiteralTrue(@Nullable Expression expression) {
return J.Literal.isLiteralValue(expression, Boolean.TRUE);
}
diff --git a/src/test/java/org/openrewrite/staticanalysis/SimplifyConstantIfBranchExecutionTest.java b/src/test/java/org/openrewrite/staticanalysis/SimplifyConstantIfBranchExecutionTest.java
index 6e1ab060cf..9393e24dfa 100644
--- a/src/test/java/org/openrewrite/staticanalysis/SimplifyConstantIfBranchExecutionTest.java
+++ b/src/test/java/org/openrewrite/staticanalysis/SimplifyConstantIfBranchExecutionTest.java
@@ -666,18 +666,54 @@ public void test() {
}
@Test
- void doesNotRemoveWhenReturnInIfBlock() {
+ void removesWhenReturnInThenBlock() {
rewriteRun(
//language=java
java(
"""
public class A {
public void test() {
+ System.out.println("before");
if (true) {
- System.out.println("hello");
+ System.out.println("then");
return;
}
- System.out.println("goodbye");
+ System.out.println("after");
+ }
+ }
+ """,
+ """
+ public class A {
+ public void test() {
+ System.out.println("before");
+ System.out.println("then");
+ return;
+ }
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void removesWhenReturnInThenNoBlock() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ public class A {
+ public void test() {
+ System.out.println("before");
+ if (true) return;
+ System.out.println("after");
+ }
+ }
+ """,
+ """
+ public class A {
+ public void test() {
+ System.out.println("before");
+ return;
}
}
"""
@@ -686,18 +722,151 @@ public void test() {
}
@Test
- void doesNotRemoveWhenThrowsInIfBlock() {
+ void removesWhenReturnInThenBlockWithElse() {
rewriteRun(
//language=java
java(
"""
public class A {
public void test() {
+ System.out.println("before");
if (true) {
- System.out.println("hello");
+ System.out.println("then");
+ return;
+ } else {
+ System.out.println("else");
+ }
+ System.out.println("after");
+ }
+ }
+ """,
+ """
+ public class A {
+ public void test() {
+ System.out.println("before");
+ System.out.println("then");
+ return;
+ }
+ }"""
+ )
+ );
+ }
+
+ @Test
+ void removesWhenReturnInElseBlock() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ public class A {
+ public void test() {
+ System.out.println("before");
+ if (false) {
+ System.out.println("then");
+ } else {
+ System.out.println("else");
+ return;
+ }
+ System.out.println("after");
+ }
+ }
+ """,
+ """
+ public class A {
+ public void test() {
+ System.out.println("before");
+ System.out.println("else");
+ return;
+ }
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void removesWhenReturnInElseNoBlock() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ public class A {
+ public void test() {
+ System.out.println("before");
+ if (false) {
+ System.out.println("then");
+ } else return;
+ System.out.println("after");
+ }
+ }
+ """,
+ """
+ public class A {
+ public void test() {
+ System.out.println("before");
+ return;
+ }
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void removesWhenThrowsInThenBlock() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ public class A {
+ public void test() {
+ System.out.println("before");
+ if (true) {
+ System.out.println("then");
throw new RuntimeException();
}
- System.out.println("goodbye");
+ System.out.println("after");
+ }
+ }
+ """,
+ """
+ public class A {
+ public void test() {
+ System.out.println("before");
+ System.out.println("then");
+ throw new RuntimeException();
+ }
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void removesWhenThrowsInThenBlockWithElse() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ public class A {
+ public void test() {
+ System.out.println("before");
+ if (true) {
+ System.out.println("then");
+ throw new RuntimeException();
+ } else {
+ System.out.println("else");
+ }
+ System.out.println("after");
+ }
+ }
+ """,
+ """
+ public class A {
+ public void test() {
+ System.out.println("before");
+ System.out.println("then");
+ throw new RuntimeException();
}
}
"""
@@ -706,21 +875,68 @@ public void test() {
}
@Test
- void doesNotRemoveWhenBreakInIfBlockWithinWhile() {
+ void removesWhenThrowsInElseBlock() {
rewriteRun(
//language=java
java(
"""
public class A {
public void test() {
- while (true){
+ System.out.println("before");
+ if (false) {
+ System.out.println("then");
+ } else {
+ System.out.println("else");
+ throw new RuntimeException();
+ }
+ System.out.println("after");
+ }
+ }
+ """,
+ """
+ public class A {
+ public void test() {
+ System.out.println("before");
+ System.out.println("else");
+ throw new RuntimeException();
+ }
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void removesWhenBreakInThenBlockWithinWhile() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ public class A {
+ public void test() {
+ System.out.println("before while");
+ while (true) {
+ System.out.println("before if");
if (true) {
- System.out.println("hello");
+ System.out.println("then");
break;
}
- System.out.println("goodbye");
+ System.out.println("after if");
}
- System.out.println("goodbye");
+ System.out.println("after while");
+ }
+ }
+ """,
+ """
+ public class A {
+ public void test() {
+ System.out.println("before while");
+ while (true) {
+ System.out.println("before if");
+ System.out.println("then");
+ break;
+ }
+ System.out.println("after while");
}
}
"""
@@ -729,21 +945,231 @@ public void test() {
}
@Test
- void doesNotRemoveWhenContinueInIfBlockWithinWhile() {
+ void removesWhenBreakInThenBlockWithElseWithinWhile() {
rewriteRun(
//language=java
java(
"""
public class A {
public void test() {
+ System.out.println("before while");
while (true) {
+ System.out.println("before if");
if (true) {
- System.out.println("hello");
+ System.out.println("then");
+ break;
+ } else {
+ System.out.println("else");
+ }
+ System.out.println("after if");
+ }
+ System.out.println("after while");
+ }
+ }
+ """,
+ """
+ public class A {
+ public void test() {
+ System.out.println("before while");
+ while (true) {
+ System.out.println("before if");
+ System.out.println("then");
+ break;
+ }
+ System.out.println("after while");
+ }
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void removesWhenBreakInElseBlockWithinWhile() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ public class A {
+ public void test() {
+ System.out.println("before while");
+ while (true) {
+ System.out.println("before if");
+ if (false) {
+ System.out.println("then");
+ } else {
+ System.out.println("else");
+ break;
+ }
+ System.out.println("after if");
+ }
+ System.out.println("after while");
+ }
+ }
+ """,
+ """
+ public class A {
+ public void test() {
+ System.out.println("before while");
+ while (true) {
+ System.out.println("before if");
+ System.out.println("else");
+ break;
+ }
+ System.out.println("after while");
+ }
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void removesWhenContinueInThenBlockWithinWhile() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ public class A {
+ public void test() {
+ System.out.println("before while");
+ while (true) {
+ System.out.println("before if");
+ if (true) {
+ System.out.println("then");
continue;
}
- System.out.println("goodbye");
+ System.out.println("after if");
}
- System.out.println("goodbye");
+ System.out.println("after while");
+ }
+ }
+ """,
+ """
+ public class A {
+ public void test() {
+ System.out.println("before while");
+ while (true) {
+ System.out.println("before if");
+ System.out.println("then");
+ continue;
+ }
+ System.out.println("after while");
+ }
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void removesWhenContinueInThenBlockWithElseWithinWhile() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ public class A {
+ public void test() {
+ System.out.println("before while");
+ while (true) {
+ System.out.println("before if");
+ if (true) {
+ System.out.println("then");
+ continue;
+ } else {
+ System.out.println("else");
+ }
+ System.out.println("after if");
+ }
+ System.out.println("after while");
+ }
+ }
+ """,
+ """
+ public class A {
+ public void test() {
+ System.out.println("before while");
+ while (true) {
+ System.out.println("before if");
+ System.out.println("then");
+ continue;
+ }
+ System.out.println("after while");
+ }
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void removesWhenContinueInElseBlockWithinWhile() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ public class A {
+ public void test() {
+ System.out.println("before while");
+ while (true) {
+ System.out.println("before if");
+ if (false) {
+ System.out.println("then");
+ } else {
+ System.out.println("else");
+ continue;
+ }
+ System.out.println("after if");
+ }
+ System.out.println("after while");
+ }
+ }
+ """,
+ """
+ public class A {
+ public void test() {
+ System.out.println("before while");
+ while (true) {
+ System.out.println("before if");
+ System.out.println("else");
+ continue;
+ }
+ System.out.println("after while");
+ }
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void removesNestedWithReturn() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ public class A {
+ public void test() {
+ System.out.println("before outer if");
+ if (true) {
+ System.out.println("outer then");
+ if (true) {
+ System.out.println("inner then");
+ return;
+ }
+ System.out.println("after inner if");
+ }
+ System.out.println("after outer if");
+ }
+ }
+ """,
+ """
+ public class A {
+ public void test() {
+ System.out.println("before outer if");
+ System.out.println("outer then");
+ System.out.println("inner then");
+ return;
}
}
"""
@@ -751,6 +1177,54 @@ public void test() {
);
}
+ @Test
+ void simplifyNestedWithReturnAndThrowInTryCatch() {
+ rewriteRun(
+ //language=java
+ java(
+ """
+ public class A {
+ public void test() {
+ System.out.println("before outer if");
+ if (true) {
+ System.out.println("outer then");
+ if (true) {
+ try {
+ if(true) {
+ throw new RuntimeException("Explosion");
+ }
+ return;
+ } catch (Exception ex) {
+ System.out.println("catch");
+ }
+ }
+ System.out.println("after inner if");
+ }
+ System.out.println("after outer if");
+ }
+ }
+ """,
+ """
+ public class A {
+ public void test() {
+ System.out.println("before outer if");
+ System.out.println("outer then");
+ try {
+ throw new RuntimeException("Explosion");
+ } catch (Exception ex) {
+ System.out.println("catch");
+ }
+ System.out.println("after inner if");
+ System.out.println("after outer if");
+ }
+ }
+ """
+ )
+ );
+ }
+
+
+
@Test
void binaryOrIsAlwaysFalse() {
rewriteRun(