diff --git a/src/main/java/org/openrewrite/staticanalysis/RemoveMethodCallVisitor.java b/src/main/java/org/openrewrite/staticanalysis/RemoveMethodCallVisitor.java
index 774c63000f..52e9dd2adb 100644
--- a/src/main/java/org/openrewrite/staticanalysis/RemoveMethodCallVisitor.java
+++ b/src/main/java/org/openrewrite/staticanalysis/RemoveMethodCallVisitor.java
@@ -19,12 +19,9 @@
import org.jspecify.annotations.Nullable;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.MethodMatcher;
-import org.openrewrite.java.tree.Expression;
-import org.openrewrite.java.tree.J;
-import org.openrewrite.java.tree.MethodCall;
+import org.openrewrite.java.tree.*;
import java.util.function.BiPredicate;
-import java.util.function.Supplier;
/**
* Removes all {@link MethodCall} matching both the
@@ -48,33 +45,48 @@ public class RemoveMethodCallVisitor
extends JavaIsoVisitor
{
@SuppressWarnings("NullableProblems")
@Override
public J.@Nullable NewClass visitNewClass(J.NewClass newClass, P p) {
- return visitMethodCall(newClass, () -> super.visitNewClass(newClass, p));
+ if (methodMatcher.matches(newClass) && predicateMatchesAllArguments(newClass) && isStatementInParentBlock(newClass)) {
+ if (newClass.getMethodType() != null) {
+ maybeRemoveImport(newClass.getMethodType().getDeclaringType());
+ }
+ return null;
+ }
+ return super.visitNewClass(newClass, p);
}
@SuppressWarnings("NullableProblems")
@Override
public J.@Nullable MethodInvocation visitMethodInvocation(J.MethodInvocation method, P p) {
- return visitMethodCall(method, () -> super.visitMethodInvocation(method, p));
- }
+ // Find method invocations that match the specified method and arguments
+ if (methodMatcher.matches(method) && predicateMatchesAllArguments(method)) {
+ // If the method invocation is a standalone statement, remove it altogether
+ if (isStatementInParentBlock(method)) {
+ if (method.getMethodType() != null) {
+ maybeRemoveImport(method.getMethodType().getDeclaringType());
+ }
+ return null;
+ }
- private @Nullable M visitMethodCall(M methodCall, Supplier visitSuper) {
- if (!methodMatcher.matches(methodCall)) {
- return visitSuper.get();
- }
- J.Block parentBlock = getCursor().firstEnclosing(J.Block.class);
- //noinspection SuspiciousMethodCalls
- if (parentBlock != null && !parentBlock.getStatements().contains(methodCall)) {
- return visitSuper.get();
- }
- // Remove the method invocation when the argumentMatcherPredicate is true for all arguments
- for (int i = 0; i < methodCall.getArguments().size(); i++) {
- if (!argumentPredicate.test(i, methodCall.getArguments().get(i))) {
- return visitSuper.get();
+ // If the method invocation is in a fluent chain, remove just the current invocation
+ if (method.getSelect() instanceof J.MethodInvocation &&
+ TypeUtils.isOfType(method.getType(), method.getSelect().getType())) {
+ return super.visitMethodInvocation((J.MethodInvocation) method.getSelect(), p);
}
}
- if (methodCall.getMethodType() != null) {
- maybeRemoveImport(methodCall.getMethodType().getDeclaringType());
+ return super.visitMethodInvocation(method, p);
+ }
+
+ private boolean predicateMatchesAllArguments(MethodCall method) {
+ for (int i = 0; i < method.getArguments().size(); i++) {
+ if (!argumentPredicate.test(i, method.getArguments().get(i))) {
+ return false;
+ }
}
- return null;
+ return true;
+ }
+
+ private boolean isStatementInParentBlock(Statement method) {
+ J.Block parentBlock = getCursor().firstEnclosing(J.Block.class);
+ return parentBlock == null || parentBlock.getStatements().contains(method);
}
}
diff --git a/src/test/java/org/openrewrite/staticanalysis/RemoveMethodCallVisitorTest.java b/src/test/java/org/openrewrite/staticanalysis/RemoveMethodCallVisitorTest.java
index a8ad6a73f3..21ac36efe4 100644
--- a/src/test/java/org/openrewrite/staticanalysis/RemoveMethodCallVisitorTest.java
+++ b/src/test/java/org/openrewrite/staticanalysis/RemoveMethodCallVisitorTest.java
@@ -43,7 +43,7 @@ void assertTrueIsRemoved() {
"""
abstract class Test {
abstract void assertTrue(boolean condition);
-
+
void test() {
System.out.println("Hello");
assertTrue(true);
@@ -53,7 +53,7 @@ void test() {
""", """
abstract class Test {
abstract void assertTrue(boolean condition);
-
+
void test() {
System.out.println("Hello");
System.out.println("World");
@@ -72,7 +72,7 @@ void assertTrueFalseIsNotRemoved() {
"""
abstract class Test {
abstract void assertTrue(boolean condition);
-
+
void test() {
System.out.println("Hello");
assertTrue(false);
@@ -95,7 +95,7 @@ void assertTrueTwoArgIsRemoved() {
"""
abstract class Test {
abstract void assertTrue(String message, boolean condition);
-
+
void test() {
System.out.println("Hello");
assertTrue("message", true);
@@ -106,7 +106,7 @@ void test() {
"""
abstract class Test {
abstract void assertTrue(String message, boolean condition);
-
+
void test() {
System.out.println("Hello");
System.out.println("World");
@@ -125,7 +125,7 @@ void doesNotRemoveAssertTrueIfReturnValueIsUsed() {
"""
abstract class Test {
abstract int assertTrue(boolean condition);
-
+
void test() {
System.out.println("Hello");
int value = assertTrue(true);
@@ -136,4 +136,34 @@ void test() {
)
);
}
+
+ @Test
+ void removeMethodCallFromFluentChain() {
+ rewriteRun(
+ spec -> spec.recipe(RewriteTest.toRecipe(() -> new RemoveMethodCallVisitor<>(
+ new MethodMatcher("java.lang.StringBuilder append(..)"), (i, e) -> true))),
+ // language=java
+ java(
+ """
+ class Main {
+ void hello() {
+ final String s = new StringBuilder("hello")
+ .delete(1, 2)
+ .append("world")
+ .toString();
+ }
+ }
+ """,
+ """
+ class Main {
+ void hello() {
+ final String s = new StringBuilder("hello")
+ .delete(1, 2)
+ .toString();
+ }
+ }
+ """
+ )
+ );
+ }
}