diff --git a/src/main/java/org/openrewrite/staticanalysis/RemoveMethodCallVisitor.java b/src/main/java/org/openrewrite/staticanalysis/RemoveMethodCallVisitor.java index 774c63000f..52e9dd2adb 100644 --- a/src/main/java/org/openrewrite/staticanalysis/RemoveMethodCallVisitor.java +++ b/src/main/java/org/openrewrite/staticanalysis/RemoveMethodCallVisitor.java @@ -19,12 +19,9 @@ import org.jspecify.annotations.Nullable; import org.openrewrite.java.JavaIsoVisitor; import org.openrewrite.java.MethodMatcher; -import org.openrewrite.java.tree.Expression; -import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.MethodCall; +import org.openrewrite.java.tree.*; import java.util.function.BiPredicate; -import java.util.function.Supplier; /** * Removes all {@link MethodCall} matching both the @@ -48,33 +45,48 @@ public class RemoveMethodCallVisitor

extends JavaIsoVisitor

{ @SuppressWarnings("NullableProblems") @Override public J.@Nullable NewClass visitNewClass(J.NewClass newClass, P p) { - return visitMethodCall(newClass, () -> super.visitNewClass(newClass, p)); + if (methodMatcher.matches(newClass) && predicateMatchesAllArguments(newClass) && isStatementInParentBlock(newClass)) { + if (newClass.getMethodType() != null) { + maybeRemoveImport(newClass.getMethodType().getDeclaringType()); + } + return null; + } + return super.visitNewClass(newClass, p); } @SuppressWarnings("NullableProblems") @Override public J.@Nullable MethodInvocation visitMethodInvocation(J.MethodInvocation method, P p) { - return visitMethodCall(method, () -> super.visitMethodInvocation(method, p)); - } + // Find method invocations that match the specified method and arguments + if (methodMatcher.matches(method) && predicateMatchesAllArguments(method)) { + // If the method invocation is a standalone statement, remove it altogether + if (isStatementInParentBlock(method)) { + if (method.getMethodType() != null) { + maybeRemoveImport(method.getMethodType().getDeclaringType()); + } + return null; + } - private @Nullable M visitMethodCall(M methodCall, Supplier visitSuper) { - if (!methodMatcher.matches(methodCall)) { - return visitSuper.get(); - } - J.Block parentBlock = getCursor().firstEnclosing(J.Block.class); - //noinspection SuspiciousMethodCalls - if (parentBlock != null && !parentBlock.getStatements().contains(methodCall)) { - return visitSuper.get(); - } - // Remove the method invocation when the argumentMatcherPredicate is true for all arguments - for (int i = 0; i < methodCall.getArguments().size(); i++) { - if (!argumentPredicate.test(i, methodCall.getArguments().get(i))) { - return visitSuper.get(); + // If the method invocation is in a fluent chain, remove just the current invocation + if (method.getSelect() instanceof J.MethodInvocation && + TypeUtils.isOfType(method.getType(), method.getSelect().getType())) { + return super.visitMethodInvocation((J.MethodInvocation) method.getSelect(), p); } } - if (methodCall.getMethodType() != null) { - maybeRemoveImport(methodCall.getMethodType().getDeclaringType()); + return super.visitMethodInvocation(method, p); + } + + private boolean predicateMatchesAllArguments(MethodCall method) { + for (int i = 0; i < method.getArguments().size(); i++) { + if (!argumentPredicate.test(i, method.getArguments().get(i))) { + return false; + } } - return null; + return true; + } + + private boolean isStatementInParentBlock(Statement method) { + J.Block parentBlock = getCursor().firstEnclosing(J.Block.class); + return parentBlock == null || parentBlock.getStatements().contains(method); } } diff --git a/src/test/java/org/openrewrite/staticanalysis/RemoveMethodCallVisitorTest.java b/src/test/java/org/openrewrite/staticanalysis/RemoveMethodCallVisitorTest.java index a8ad6a73f3..21ac36efe4 100644 --- a/src/test/java/org/openrewrite/staticanalysis/RemoveMethodCallVisitorTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/RemoveMethodCallVisitorTest.java @@ -43,7 +43,7 @@ void assertTrueIsRemoved() { """ abstract class Test { abstract void assertTrue(boolean condition); - + void test() { System.out.println("Hello"); assertTrue(true); @@ -53,7 +53,7 @@ void test() { """, """ abstract class Test { abstract void assertTrue(boolean condition); - + void test() { System.out.println("Hello"); System.out.println("World"); @@ -72,7 +72,7 @@ void assertTrueFalseIsNotRemoved() { """ abstract class Test { abstract void assertTrue(boolean condition); - + void test() { System.out.println("Hello"); assertTrue(false); @@ -95,7 +95,7 @@ void assertTrueTwoArgIsRemoved() { """ abstract class Test { abstract void assertTrue(String message, boolean condition); - + void test() { System.out.println("Hello"); assertTrue("message", true); @@ -106,7 +106,7 @@ void test() { """ abstract class Test { abstract void assertTrue(String message, boolean condition); - + void test() { System.out.println("Hello"); System.out.println("World"); @@ -125,7 +125,7 @@ void doesNotRemoveAssertTrueIfReturnValueIsUsed() { """ abstract class Test { abstract int assertTrue(boolean condition); - + void test() { System.out.println("Hello"); int value = assertTrue(true); @@ -136,4 +136,34 @@ void test() { ) ); } + + @Test + void removeMethodCallFromFluentChain() { + rewriteRun( + spec -> spec.recipe(RewriteTest.toRecipe(() -> new RemoveMethodCallVisitor<>( + new MethodMatcher("java.lang.StringBuilder append(..)"), (i, e) -> true))), + // language=java + java( + """ + class Main { + void hello() { + final String s = new StringBuilder("hello") + .delete(1, 2) + .append("world") + .toString(); + } + } + """, + """ + class Main { + void hello() { + final String s = new StringBuilder("hello") + .delete(1, 2) + .toString(); + } + } + """ + ) + ); + } }