Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
}

private static class ReplaceLambdaWithMethodReferenceKotlinVisitor extends KotlinVisitor<ExecutionContext> {
// Implement Me
// XXX Implement Me
}

private static class ReplaceLambdaWithMethodReferenceJavaVisitor extends JavaVisitor<ExecutionContext> {
Expand Down Expand Up @@ -195,6 +195,10 @@ public J visitLambda(J.Lambda lambda, ExecutionContext ctx) {
.anyMatch(JavaType.GenericTypeVariable.class::isInstance)) {
return l;
}
// Check if transforming would break type inference in nested generic/overloaded context
if (isLambdaInGenericAndOverloadedContext()) {
return l;
}
J.MemberReference updated = newStaticMethodReference(methodType, true, lambda.getType()).withPrefix(lambda.getPrefix());
doAfterVisit(service(ImportService.class).shortenFullyQualifiedTypeReferencesIn(updated));
return updated;
Expand Down Expand Up @@ -358,10 +362,68 @@ private boolean isMethodReferenceAmbiguous(JavaType.Method method) {
}
return false;
}

/**
* Check if the lambda is in a context where converting it to a method reference
* would break type inference. This occurs when a lambda's return type depends on
* generic type inference, and the lambda is passed through method calls where
* one of the enclosing methods is overloaded.
* <p>
* Example: foo(fold(() -> Optional.empty()))
* where fold is generic and foo is overloaded.
*/
private boolean isLambdaInGenericAndOverloadedContext() {
// Walk up the cursor tree to find enclosing method invocations
Cursor cursor = getCursor();

// Find the first method invocation that the lambda is an argument to
Cursor parent = cursor.dropParentUntil(p -> p instanceof J.MethodInvocation || p instanceof SourceFile);
if (!(parent.getValue() instanceof J.MethodInvocation)) {
return false;
}

J.MethodInvocation innerMethod = parent.getValue();

// Now check if there's an enclosing overloaded method where innerMethod is an argument
// Start from the parent of the first method invocation
Cursor grandparent = parent.getParent();
if (grandparent == null) {
return false;
}

grandparent = grandparent.dropParentUntil(p -> p instanceof J.MethodInvocation || p instanceof SourceFile);
if (!(grandparent.getValue() instanceof J.MethodInvocation)) {
return false;
}

J.MethodInvocation outerMethod = grandparent.getValue();

// Check that innerMethod is actually an ARGUMENT to outerMethod, not just chained
boolean isInnerMethodAnArgument = outerMethod.getArguments().stream()
.anyMatch(arg -> arg == innerMethod);

if (!isInnerMethodAnArgument) {
return false;
}

JavaType.Method outerType = outerMethod.getMethodType();
if (outerType == null) {
return false;
}

// Check if the outer method is overloaded
long overloadCount = outerType.getDeclaringType().getMethods().stream()
.filter(m -> m.getName().equals(outerType.getName()) && !m.isConstructor())
.count();

// If we have nested method calls where the outer one is overloaded,
// be conservative and don't transform to avoid breaking type inference
return overloadCount > 1;
}
}

private static boolean isAMethodInvocationArgument(J.Lambda lambda, Cursor cursor) {
Cursor parent = cursor.dropParentUntil(p -> p instanceof J.MethodInvocation || p instanceof J.CompilationUnit);
Cursor parent = cursor.dropParentUntil(p -> p instanceof J.MethodInvocation || p instanceof SourceFile);
if (parent.getValue() instanceof J.MethodInvocation) {
J.MethodInvocation m = parent.getValue();
return m.getArguments().stream().anyMatch(arg -> arg == lambda);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1576,4 +1576,64 @@ private void foo() {
)
);
}

@Issue("https://github.com/openrewrite/rewrite-static-analysis/issues/774")
@Test
void methodRefWithGenerics() {
rewriteRun(
//language=java
java(
"""
import java.util.Optional;
import java.util.function.Supplier;

class Foo {
<R> R fold(final Supplier<R> supplier) {return null;}

void foo(String l) {}
void foo(Optional<String> l) {}

void bar() {
foo(fold(() -> Optional.empty()));
}
}
"""
)
);
}

@Test
void simpleGenericMethodTest() {
rewriteRun(
//language=java
java(
"""
import java.util.function.Supplier;

class Foo {
<R> R fold(final Supplier<R> supplier) {return null;}

String getString() { return "test"; }

void bar() {
String result = fold(() -> getString());
}
}
""",
"""
import java.util.function.Supplier;

class Foo {
<R> R fold(final Supplier<R> supplier) {return null;}

String getString() { return "test"; }

void bar() {
String result = fold(this::getString);
}
}
"""
)
);
}
}