diff --git a/src/main/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReference.java b/src/main/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReference.java index ae56e951fe..467bc140e5 100644 --- a/src/main/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReference.java +++ b/src/main/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReference.java @@ -159,6 +159,7 @@ public J visitLambda(J.Lambda lambda, ExecutionContext ctx) { } if (hasSelectWithPotentialSideEffects(method) || + hasSelectWhoseReferenceMightChange(method) || !methodArgumentsMatchLambdaParameters(method, lambda) || method instanceof J.MemberReference) { return l; @@ -212,6 +213,14 @@ private boolean hasSelectWithPotentialSideEffects(MethodCall method) { ((J.MethodInvocation) method).getSelect() instanceof MethodCall; } + private boolean hasSelectWhoseReferenceMightChange(MethodCall method) { + if (method instanceof J.MethodInvocation && ((J.MethodInvocation) method).getSelect() instanceof J.Identifier) { + JavaType.Variable fieldType = ((J.Identifier) ((J.MethodInvocation) method).getSelect()).getFieldType(); + return fieldType != null && fieldType.getOwner() instanceof JavaType.Class && !fieldType.hasFlags(Flag.Final); + } + return false; + } + private boolean methodArgumentsMatchLambdaParameters(MethodCall method, J.Lambda lambda) { JavaType.Method methodType = method.getMethodType(); if (methodType == null) { diff --git a/src/test/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReferenceTest.java b/src/test/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReferenceTest.java index a16f2585f0..9b764f038d 100644 --- a/src/test/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReferenceTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReferenceTest.java @@ -63,7 +63,8 @@ void multipleMethodInvocations() { """ import java.nio.file.Path; import java.nio.file.Paths; - import java.util.List;import java.util.stream.Collectors; + import java.util.List; + import java.util.stream.Collectors; class Test { Path path = Paths.get(""); @@ -190,6 +191,7 @@ List method(List input) { } } """, + //language=java """ import java.util.List; import java.util.stream.Collectors; @@ -237,6 +239,7 @@ List method(List input) { } } """, + //language=java """ import org.test.CheckType; @@ -288,6 +291,7 @@ Stream method(List input) { } } """, + //language=java """ import java.util.List; import java.util.stream.Stream; @@ -409,8 +413,10 @@ public void run() { Collections.singletonList(1).forEach(n -> run()); } } - Test t = new Test(); - Runnable r = () -> t.run(); + void foo() { + Test t = new Test(); + Runnable r = () -> t.run(); + } } """, """ @@ -423,8 +429,10 @@ public void run() { Collections.singletonList(1).forEach(n -> run()); } } - Test t = new Test(); - Runnable r = t::run; + void foo() { + Test t = new Test(); + Runnable r = t::run; + } } """ ) @@ -595,6 +603,7 @@ List filter(List l) { } } """, + //language=java """ import org.test.CheckType; @@ -1293,8 +1302,8 @@ private Integer foo(String bar) { @Test void newClassSelector() { + //language=java rewriteRun( - //language=java java( """ class A { @@ -1345,4 +1354,22 @@ public void groupOnGetClass() { ); } + @Test + void dontReplaceNullableFieldReferences() { + //language=java + rewriteRun( + java( + """ + import java.util.function.Supplier; + class A { + Object field; + void foo() { + // Runtime exception when replaced with field::toString + Supplier supplier = () -> field.toString(); + } + } + """ + ) + ); + } }