diff --git a/src/main/java/org/openrewrite/staticanalysis/DefaultComesLastVisitor.java b/src/main/java/org/openrewrite/staticanalysis/DefaultComesLastVisitor.java index b9c651938f..b8ceb05cf5 100644 --- a/src/main/java/org/openrewrite/staticanalysis/DefaultComesLastVisitor.java +++ b/src/main/java/org/openrewrite/staticanalysis/DefaultComesLastVisitor.java @@ -94,8 +94,7 @@ public J.Switch visitSwitch(J.Switch switch_, P p) { casesWithDefaultLast = addBreakToLastCase(casesWithDefaultLast, p); casesWithDefaultLast.addAll(maybeReorderFallthroughCases(defaultCases, p)); - casesWithDefaultLast = ListUtils.mapLast(casesWithDefaultLast, this::removeBreak); - return casesWithDefaultLast; + return ListUtils.mapLast(casesWithDefaultLast, this::removeBreak); } private List maybeReorderFallthroughCases(List cases, P p) { diff --git a/src/main/java/org/openrewrite/staticanalysis/FinalizeMethodArguments.java b/src/main/java/org/openrewrite/staticanalysis/FinalizeMethodArguments.java index 25d623961c..00d8ad0c5d 100644 --- a/src/main/java/org/openrewrite/staticanalysis/FinalizeMethodArguments.java +++ b/src/main/java/org/openrewrite/staticanalysis/FinalizeMethodArguments.java @@ -23,7 +23,10 @@ import org.openrewrite.TreeVisitor; import org.openrewrite.internal.ListUtils; import org.openrewrite.java.JavaIsoVisitor; -import org.openrewrite.java.tree.*; +import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.JavaType; +import org.openrewrite.java.tree.Space; +import org.openrewrite.java.tree.Statement; import org.openrewrite.marker.Markers; import java.util.List; @@ -69,8 +72,7 @@ public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration methodDecl } List parameters = ListUtils.map(declarations.getParameters(), FinalizeMethodArguments::updateParam); - declarations = declarations.withParameters(parameters); - return declarations; + return declarations.withParameters(parameters); } private void checkIfAssigned(final AtomicBoolean assigned, final Statement p) { @@ -176,8 +178,7 @@ private static Statement updateParam(final Statement p) { J.VariableDeclarations variableDeclarations = (J.VariableDeclarations) p; if (variableDeclarations.getModifiers().isEmpty()) { variableDeclarations = updateModifiers(variableDeclarations, !((J.VariableDeclarations) p).getLeadingAnnotations().isEmpty()); - variableDeclarations = updateDeclarations(variableDeclarations); - return variableDeclarations; + return updateDeclarations(variableDeclarations); } } return p; diff --git a/src/main/java/org/openrewrite/staticanalysis/FixStringFormatExpressions.java b/src/main/java/org/openrewrite/staticanalysis/FixStringFormatExpressions.java index a753295788..ecf22b8ac7 100644 --- a/src/main/java/org/openrewrite/staticanalysis/FixStringFormatExpressions.java +++ b/src/main/java/org/openrewrite/staticanalysis/FixStringFormatExpressions.java @@ -95,13 +95,12 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu argIndex++; } int finalArgIndex = argIndex; - mi = mi.withArguments(ListUtils.map(mi.getArguments(), (i, arg) -> { + return mi.withArguments(ListUtils.map(mi.getArguments(), (i, arg) -> { if (i == 0 || i < finalArgIndex) { return arg; } return null; })); - return mi; } return mi; } diff --git a/src/main/java/org/openrewrite/staticanalysis/ForLoopIncrementInUpdate.java b/src/main/java/org/openrewrite/staticanalysis/ForLoopIncrementInUpdate.java index 39e767cb79..e73aa41431 100644 --- a/src/main/java/org/openrewrite/staticanalysis/ForLoopIncrementInUpdate.java +++ b/src/main/java/org/openrewrite/staticanalysis/ForLoopIncrementInUpdate.java @@ -87,16 +87,13 @@ public J visitForLoop(J.ForLoop forLoop, ExecutionContext ctx) { Comparator.comparing(s -> s.printTrimmed(getCursor()), Comparator.naturalOrder()) ))); - //noinspection ConstantConditions - f = f.withBody((Statement) new JavaVisitor() { + return f.withBody((Statement) new JavaVisitor() { @Override public @Nullable J visit(@Nullable Tree tree, ExecutionContext ctx) { return tree == unary ? null : super.visit(tree, ctx); } }.visit(f.getBody(), ctx)); - - return f; } } } diff --git a/src/main/java/org/openrewrite/staticanalysis/InlineVariable.java b/src/main/java/org/openrewrite/staticanalysis/InlineVariable.java index 5079b3d5ed..7136ac46e0 100644 --- a/src/main/java/org/openrewrite/staticanalysis/InlineVariable.java +++ b/src/main/java/org/openrewrite/staticanalysis/InlineVariable.java @@ -21,10 +21,8 @@ import org.openrewrite.TreeVisitor; import org.openrewrite.internal.ListUtils; import org.openrewrite.java.JavaIsoVisitor; -import org.openrewrite.java.tree.Expression; -import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.JavaType; -import org.openrewrite.java.tree.Statement; +import org.openrewrite.java.search.SemanticallyEqual; +import org.openrewrite.java.tree.*; import java.time.Duration; import java.util.Collections; @@ -42,7 +40,8 @@ public String getDisplayName() { @Override public String getDescription() { - return "Inline variables when they are immediately used to return or throw."; + return "Inline variables when they are immediately used to return or throw. " + + "Supports both variable declarations and assignments to local variables."; } @Override @@ -62,33 +61,31 @@ public TreeVisitor getVisitor() { public J.Block visitBlock(J.Block block, ExecutionContext ctx) { J.Block bl = super.visitBlock(block, ctx); List statements = bl.getStatements(); - if (statements.size() > 1) { - String identReturned = identReturned(statements); + if (1 < statements.size()) { + J.Identifier identReturned = identReturnedOrThrown(statements); if (identReturned != null) { - if (statements.get(statements.size() - 2) instanceof J.VariableDeclarations) { - J.VariableDeclarations varDec = (J.VariableDeclarations) statements.get(statements.size() - 2); - J.VariableDeclarations.NamedVariable identDefinition = varDec.getVariables().get(0); - if (varDec.getLeadingAnnotations().isEmpty() && identDefinition.getSimpleName().equals(identReturned)) { - bl = bl.withStatements(ListUtils.map(statements, (i, statement) -> { - if (i == statements.size() - 2) { - return null; - } - if (i == statements.size() - 1) { - if (statement instanceof J.Return) { - J.Return return_ = (J.Return) statement; - return return_.withExpression(requireNonNull(identDefinition.getInitializer()) - .withPrefix(requireNonNull(return_.getExpression()).getPrefix())) - .withPrefix(varDec.getPrefix().withComments(ListUtils.concatAll(varDec.getComments(), return_.getComments()))); - } - if (statement instanceof J.Throw) { - J.Throw thrown = (J.Throw) statement; - return thrown.withException(requireNonNull(identDefinition.getInitializer()) - .withPrefix(requireNonNull(thrown.getException()).getPrefix())) - .withPrefix(varDec.getPrefix().withComments(ListUtils.concatAll(varDec.getComments(), thrown.getComments()))); - } - } - return statement; - })); + Statement secondLastStatement = statements.get(statements.size() - 2); + if (secondLastStatement instanceof J.VariableDeclarations) { + J.VariableDeclarations varDec = (J.VariableDeclarations) secondLastStatement; + // Only inline if there's exactly one variable declared + if (varDec.getVariables().size() == 1) { + J.VariableDeclarations.NamedVariable identDefinition = varDec.getVariables().get(0); + if (varDec.getLeadingAnnotations().isEmpty() && + SemanticallyEqual.areEqual(identDefinition.getName(), identReturned)) { + return inlineExpression(identDefinition.getInitializer(), bl, statements, varDec.getPrefix(), varDec.getComments()); + } + } + } else if (secondLastStatement instanceof J.Assignment) { + J.Assignment assignment = (J.Assignment) secondLastStatement; + if (assignment.getVariable() instanceof J.Identifier) { + J.Identifier assignedVar = (J.Identifier) assignment.getVariable(); + // Only inline local variable assignments, not fields + if (assignedVar.getFieldType() != null && + assignedVar.getFieldType().getOwner() instanceof JavaType.Method && + SemanticallyEqual.areEqual(assignedVar, identReturned)) { + doAfterVisit(new RemoveUnusedLocalVariables(null, null, null).getVisitor()); + return inlineExpression(assignment.getAssignment(), bl, statements, assignment.getPrefix(), assignment.getComments()); + } } } } @@ -96,19 +93,47 @@ public J.Block visitBlock(J.Block block, ExecutionContext ctx) { return bl; } - private @Nullable String identReturned(List stats) { + private J.Block inlineExpression(@Nullable Expression expression, J.Block bl, List statements, + Space prefix, List comments) { + if (expression == null) { + return bl; + } + + return bl.withStatements(ListUtils.map(statements, (i, statement) -> { + if (i == statements.size() - 2) { + return null; + } + if (i == statements.size() - 1) { + if (statement instanceof J.Return) { + J.Return return_ = (J.Return) statement; + return return_ + .withExpression(expression.withPrefix(requireNonNull(return_.getExpression()).getPrefix())) + .withPrefix(prefix.withComments(ListUtils.concatAll(comments, return_.getComments()))); + } + if (statement instanceof J.Throw) { + J.Throw thrown = (J.Throw) statement; + return thrown. + withException(expression.withPrefix(requireNonNull(thrown.getException()).getPrefix())) + .withPrefix(prefix.withComments(ListUtils.concatAll(comments, thrown.getComments()))); + } + } + return statement; + })); + } + + private J.@Nullable Identifier identReturnedOrThrown(List stats) { Statement lastStatement = stats.get(stats.size() - 1); if (lastStatement instanceof J.Return) { J.Return return_ = (J.Return) lastStatement; Expression expression = return_.getExpression(); if (expression instanceof J.Identifier && - !(expression.getType() instanceof JavaType.Array)) { - return ((J.Identifier) expression).getSimpleName(); + !(expression.getType() instanceof JavaType.Array)) { + return ((J.Identifier) expression); } } else if (lastStatement instanceof J.Throw) { J.Throw thr = (J.Throw) lastStatement; if (thr.getException() instanceof J.Identifier) { - return ((J.Identifier) thr.getException()).getSimpleName(); + return ((J.Identifier) thr.getException()); } } return null; diff --git a/src/main/java/org/openrewrite/staticanalysis/InstanceOfPatternMatch.java b/src/main/java/org/openrewrite/staticanalysis/InstanceOfPatternMatch.java index 36af559a6e..8b86e0066b 100644 --- a/src/main/java/org/openrewrite/staticanalysis/InstanceOfPatternMatch.java +++ b/src/main/java/org/openrewrite/staticanalysis/InstanceOfPatternMatch.java @@ -424,8 +424,7 @@ public J visitBinary(J.Binary binary, Integer p) { @Override public J.InstanceOf visitInstanceOf(J.InstanceOf instanceOf, Integer p) { instanceOf = (J.InstanceOf) super.visitInstanceOf(instanceOf, p); - instanceOf = replacements.processInstanceOf(instanceOf, getCursor()); - return instanceOf; + return replacements.processInstanceOf(instanceOf, getCursor()); } @Override diff --git a/src/main/java/org/openrewrite/staticanalysis/MinimumSwitchCases.java b/src/main/java/org/openrewrite/staticanalysis/MinimumSwitchCases.java index 5c56e031d3..242689070b 100644 --- a/src/main/java/org/openrewrite/staticanalysis/MinimumSwitchCases.java +++ b/src/main/java/org/openrewrite/staticanalysis/MinimumSwitchCases.java @@ -255,7 +255,6 @@ private boolean switchesOnEnum(J.Switch switch_) { } private static J.If createIfForEnum(Expression expression, Expression enumTree) { - J.If generatedIf; if (enumTree instanceof J.Identifier) { enumTree = new J.FieldAccess( randomId(), @@ -267,7 +266,7 @@ private static J.If createIfForEnum(Expression expression, Expression enumTree) ); } J.Binary ifCond = JavaElementFactory.newLogicalExpression(J.Binary.Type.Equal, expression, enumTree); - generatedIf = new J.If( + return new J.If( randomId(), Space.EMPTY, Markers.EMPTY, @@ -275,7 +274,6 @@ private static J.If createIfForEnum(Expression expression, Expression enumTree) JRightPadded.build(J.Block.createEmptyBlock()), null ); - return generatedIf; } @AllArgsConstructor diff --git a/src/main/java/org/openrewrite/staticanalysis/NoFinalizer.java b/src/main/java/org/openrewrite/staticanalysis/NoFinalizer.java index 5ec7d90400..d25c79d0db 100644 --- a/src/main/java/org/openrewrite/staticanalysis/NoFinalizer.java +++ b/src/main/java/org/openrewrite/staticanalysis/NoFinalizer.java @@ -59,7 +59,8 @@ public TreeVisitor getVisitor() { @Override public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, ExecutionContext ctx) { J.ClassDeclaration cd = super.visitClassDeclaration(classDecl, ctx); - cd = cd.withBody(cd.getBody().withStatements(ListUtils.map(cd.getBody().getStatements(), stmt -> { + + return cd.withBody(cd.getBody().withStatements(ListUtils.map(cd.getBody().getStatements(), stmt -> { if (stmt instanceof J.MethodDeclaration) { if (FINALIZER.matches((J.MethodDeclaration) stmt, classDecl)) { return null; @@ -67,8 +68,6 @@ public J.ClassDeclaration visitClassDeclaration(J.ClassDeclaration classDecl, Ex } return stmt; }))); - - return cd; } }); } diff --git a/src/main/java/org/openrewrite/staticanalysis/RemoveUnusedLocalVariables.java b/src/main/java/org/openrewrite/staticanalysis/RemoveUnusedLocalVariables.java index 3a969bee35..f3068f6a6d 100644 --- a/src/main/java/org/openrewrite/staticanalysis/RemoveUnusedLocalVariables.java +++ b/src/main/java/org/openrewrite/staticanalysis/RemoveUnusedLocalVariables.java @@ -249,17 +249,14 @@ private static class PruneAssignmentExpression extends JavaIsoVisitor J.ControlParentheses visitControlParentheses(J.ControlParentheses c, ExecutionContext ctx) { - //noinspection unchecked - c = (J.ControlParentheses) new AssignmentToLiteral(assignment) + return (J.ControlParentheses) new AssignmentToLiteral(assignment) .visitNonNull(c, ctx, getCursor().getParentOrThrow()); - return c; } @Override public J.MethodInvocation visitMethodInvocation(J.MethodInvocation m, ExecutionContext ctx) { AssignmentToLiteral atl = new AssignmentToLiteral(assignment); - m = m.withArguments(ListUtils.map(m.getArguments(), it -> (Expression) atl.visitNonNull(it, ctx, getCursor().getParentOrThrow()))); - return m; + return m.withArguments(ListUtils.map(m.getArguments(), it -> (Expression) atl.visitNonNull(it, ctx, getCursor().getParentOrThrow()))); } } diff --git a/src/main/java/org/openrewrite/staticanalysis/TypecastParenPad.java b/src/main/java/org/openrewrite/staticanalysis/TypecastParenPad.java index c817989435..71356e7628 100755 --- a/src/main/java/org/openrewrite/staticanalysis/TypecastParenPad.java +++ b/src/main/java/org/openrewrite/staticanalysis/TypecastParenPad.java @@ -66,9 +66,8 @@ public J visit(@Nullable Tree tree, ExecutionContext ctx) { @Override public J.TypeCast visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) { J.TypeCast tc = super.visitTypeCast(typeCast, ctx); - tc = (J.TypeCast) new SpacesVisitor<>(spacesStyle, null, null, tc) + return (J.TypeCast) new SpacesVisitor<>(spacesStyle, null, null, tc) .visitNonNull(tc, ctx, getCursor().getParentTreeCursor().fork()); - return tc; } } ); diff --git a/src/main/java/org/openrewrite/staticanalysis/UseAsBuilder.java b/src/main/java/org/openrewrite/staticanalysis/UseAsBuilder.java index faeb9677d9..e3e457052e 100644 --- a/src/main/java/org/openrewrite/staticanalysis/UseAsBuilder.java +++ b/src/main/java/org/openrewrite/staticanalysis/UseAsBuilder.java @@ -200,8 +200,7 @@ private J.VariableDeclarations consolidateBuilder(J.VariableDeclarations consoli ); }) ); - cb = formatTabsAndIndents(cb, getCursor()); - return cb; + return formatTabsAndIndents(cb, getCursor()); } }; diff --git a/src/main/java/org/openrewrite/staticanalysis/WhileInsteadOfFor.java b/src/main/java/org/openrewrite/staticanalysis/WhileInsteadOfFor.java index f0c1ef7046..3e9178d74a 100644 --- a/src/main/java/org/openrewrite/staticanalysis/WhileInsteadOfFor.java +++ b/src/main/java/org/openrewrite/staticanalysis/WhileInsteadOfFor.java @@ -61,8 +61,7 @@ public J visitForLoop(J.ForLoop forLoop, ExecutionContext ctx) { !(forLoop.getControl().getCondition() instanceof J.Empty) ) { J.WhileLoop w = whileLoop.apply(getCursor(), forLoop.getCoordinates().replace(), forLoop.getControl().getCondition()); - w = w.withBody(forLoop.getBody()); - return w; + return w.withBody(forLoop.getBody()); } return super.visitForLoop(forLoop, ctx); } diff --git a/src/main/java/org/openrewrite/staticanalysis/java/MoveFieldAnnotationToType.java b/src/main/java/org/openrewrite/staticanalysis/java/MoveFieldAnnotationToType.java index c4a4703132..84d98188e4 100644 --- a/src/main/java/org/openrewrite/staticanalysis/java/MoveFieldAnnotationToType.java +++ b/src/main/java/org/openrewrite/staticanalysis/java/MoveFieldAnnotationToType.java @@ -184,8 +184,7 @@ private TypeTree annotateInnerClass(TypeTree qualifiedClassRef, J.Annotation ann } if (qualifiedClassRef instanceof J.ArrayType) { J.ArrayType at = (J.ArrayType) qualifiedClassRef; - at = at.withAnnotations(ListUtils.concat(annotation.withPrefix(Space.SINGLE_SPACE), at.getAnnotations())); - return at; + return at.withAnnotations(ListUtils.concat(annotation.withPrefix(Space.SINGLE_SPACE), at.getAnnotations())); } return qualifiedClassRef; } diff --git a/src/test/java/org/openrewrite/staticanalysis/InlineVariableTest.java b/src/test/java/org/openrewrite/staticanalysis/InlineVariableTest.java index 74c54988c5..2353932a3c 100644 --- a/src/test/java/org/openrewrite/staticanalysis/InlineVariableTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/InlineVariableTest.java @@ -252,4 +252,403 @@ int[] test() { ) ); } + + @Test + void inlineAssignmentReturn() { + rewriteRun( + //language=java + java( + """ + class Test { + String test() { + String result; + result = "hello"; + return result; + } + } + """, + """ + class Test { + String test() { + return "hello"; + } + } + """ + ) + ); + } + + @Test + void inlineAssignmentThrow() { + rewriteRun( + //language=java + java( + """ + class Test { + void test() { + RuntimeException e; + e = new RuntimeException("error"); + throw e; + } + } + """, + """ + class Test { + void test() { + throw new RuntimeException("error"); + } + } + """ + ) + ); + } + + @Test + void inlineComplexAssignmentReturn() { + rewriteRun( + //language=java + java( + """ + class Test { + String test(String input) { + String result; + result = input.trim().toUpperCase(); + return result; + } + } + """, + """ + class Test { + String test(String input) { + return input.trim().toUpperCase(); + } + } + """ + ) + ); + } + + @Test + void doNotInlineFieldAssignment() { + rewriteRun( + //language=java + java( + """ + class Test { + private String field; + + String test() { + field = "value"; + return field; + } + } + """ + ) + ); + } + + + @Test + void preserveCommentsOnAssignment() { + rewriteRun( + //language=java + java( + """ + class Test { + String test() { + String result; + // important assignment + result = "value"; // trailing comment + return result; + } + } + """, + """ + class Test { + String test() { + // important assignment + // trailing comment + return "value"; + } + } + """ + ) + ); + } + + @Test + void doNotInlineWhenMultipleVariables() { + rewriteRun( + //language=java + java( + """ + class Test { + String getString() { + String a = "Hello", b = "World"; + return a; + } + } + """ + ) + ); + } + + @Test + void inlineAssignmentToParameter() { + rewriteRun( + //language=java + java( + """ + class Test { + String test(String param) { + param = toString(); + return param; + } + } + """, + """ + class Test { + String test(String param) { + return toString(); + } + } + """ + ) + ); + } + + @Test + void inlineMultipleAssignmentsBeforeReturn() { + rewriteRun( + //language=java + java( + """ + class Test { + String test() { + String variable = null; + variable = toString(); + return variable; + } + } + """, + """ + class Test { + String test() { + return toString(); + } + } + """ + ) + ); + } + + @Test + void inlineAssignmentWithMethodCall() { + rewriteRun( + //language=java + java( + """ + class Test { + String spy() { return "spy"; } + + String test() { + String variable = spy(); + return variable; + } + } + """, + """ + class Test { + String spy() { return "spy"; } + + String test() { + return spy(); + } + } + """ + ) + ); + } + + @Test + void inlineAssignmentInElseBlockWhilePreservingIfBlock() { + rewriteRun( + //language=java + java( + """ + class Test { + String test(boolean condition) { + String variable = toString(); + if (condition) { + return variable; + } else { + variable = "foo"; + return variable; + } + } + } + """, + """ + class Test { + String test(boolean condition) { + String variable = toString(); + if (condition) { + return variable; + } else { + return "foo"; + } + } + } + """ + ) + ); + } + + @Test + void inlineVariableInElseBlock() { + rewriteRun( + //language=java + java( + """ + class Test { + String test(boolean condition) { + if (condition) { + return "bar"; + } else { + String variable = "foo"; + return variable; + } + } + } + """, + """ + class Test { + String test(boolean condition) { + if (condition) { + return "bar"; + } else { + return "foo"; + } + } + } + """ + ) + ); + } + + @Test + void inlineAssignmentInTryBlock() { + rewriteRun( + //language=java + java( + """ + class Test { + String test() { + try { + String result = someMethod(); + return result; + } catch (Exception e) { + return null; + } + } + + String someMethod() throws Exception { + return "value"; + } + } + """, + """ + class Test { + String test() { + try { + return someMethod(); + } catch (Exception e) { + return null; + } + } + + String someMethod() throws Exception { + return "value"; + } + } + """ + ) + ); + } + + @Test + void inlineVariableWithCast() { + rewriteRun( + //language=java + java( + """ + class Test { + Object test() { + String str = (String) getObject(); + return str; + } + + Object getObject() { + return "string"; + } + } + """, + """ + class Test { + Object test() { + return (String) getObject(); + } + + Object getObject() { + return "string"; + } + } + """ + ) + ); + } + + @Test + void inlineVariableWithTernary() { + rewriteRun( + //language=java + java( + """ + class Test { + String test(boolean flag) { + String result = flag ? "yes" : "no"; + return result; + } + } + """, + """ + class Test { + String test(boolean flag) { + return flag ? "yes" : "no"; + } + } + """ + ) + ); + } + + @Test + void doNotInlineWhenVariableIsReassigned() { + rewriteRun( + //language=java + java( + """ + class Test { + String test(boolean condition) { + String result = "initial"; + if (condition) { + result = "changed"; + } + return result; + } + } + """ + ) + ); + } }