Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ public J visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) {
JavaType expressionType = visitedTypeCast.getExpression().getType();
JavaType castType = visitedTypeCast.getType();

if (targetType == null || targetType instanceof JavaType.Primitive && castType != expressionType) {
if (targetType == null ||
targetType instanceof JavaType.Primitive && castType != expressionType ||
typeCast.getExpression() instanceof J.Lambda && castType instanceof JavaType.Parameterized) {
// Not currently supported, this will be more accurate with dataflow analysis.
return visitedTypeCast;
} else if (!(targetType instanceof JavaType.Array) && TypeUtils.isOfClassType(targetType, "java.lang.Object") ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,35 +132,17 @@ public J visitNewClass(J.NewClass newClass, ExecutionContext ctx) {
return n;
}

@Nullable
private JavaType.Method getSamCompatible(JavaType type) {
JavaType.Method sam = null;
JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(type);
if (fullyQualified == null) {
return null;
}
for (JavaType.Method method : fullyQualified.getMethods()) {
if (method.hasFlags(Flag.Default) || method.hasFlags(Flag.Static)) {
continue;
}
if (sam != null) {
return null;
}
sam = method;
}
return sam;
}

private J maybeAddCast(J.Lambda lambda, J.NewClass original) {
J parent = getCursor().getParentTreeCursor().getValue();

if (parent instanceof J.MethodInvocation) {
J.MethodInvocation method = (J.MethodInvocation) parent;
if (parent instanceof MethodCall) {
MethodCall method = (MethodCall) parent;
List<Expression> arguments = method.getArguments();
for (int i = 0; i < arguments.size(); i++) {
Expression argument = arguments.get(i);
if (argument == original && methodArgumentRequiresCast(lambda, method, i) &&
original.getClazz() != null) {
doAfterVisit(new RemoveRedundantTypeCast().getVisitor());
return new J.TypeCast(
Tree.randomId(),
lambda.getPrefix(),
Expand All @@ -180,7 +162,7 @@ private J maybeAddCast(J.Lambda lambda, J.NewClass original) {
return lambda;
}

private boolean methodArgumentRequiresCast(J.Lambda lambda, J.MethodInvocation method, int argumentIndex) {
private boolean methodArgumentRequiresCast(J.Lambda lambda, MethodCall method, int argumentIndex) {
JavaType.FullyQualified lambdaType = TypeUtils.asFullyQualified(lambda.getType());
if (lambdaType == null) {
return false;
Expand All @@ -207,7 +189,11 @@ private boolean methodArgumentRequiresCast(J.Lambda lambda, J.MethodInvocation m
}
}
}
return count >= 2;
if (count >= 2) {
return true;
}

return hasGenerics(lambda);
}

private boolean areMethodsAmbiguous(@Nullable JavaType.Method m1, @Nullable JavaType.Method m2) {
Expand Down Expand Up @@ -419,4 +405,42 @@ public J visitVariable(J.VariableDeclarations.NamedVariable variable, Integer in

return hasShadow.get();
}

private static boolean hasGenerics(J.Lambda lambda) {
AtomicBoolean atomicBoolean = new AtomicBoolean();
new JavaVisitor<AtomicBoolean>() {
@Override
public J visitMethodInvocation(J.MethodInvocation method, AtomicBoolean atomicBoolean) {
if (method.getMethodType() != null &&
method.getMethodType().getParameterTypes().stream()
.anyMatch(p -> p instanceof JavaType.Parameterized &&
((JavaType.Parameterized) p).getTypeParameters().stream().anyMatch(t -> t instanceof JavaType.GenericTypeVariable))
) {
atomicBoolean.set(true);
}
return super.visitMethodInvocation(method, atomicBoolean);
}
}.visit(lambda.getBody(), atomicBoolean);
return atomicBoolean.get();
}

// TODO consider moving to TypeUtils
@Nullable
private static JavaType.Method getSamCompatible(@Nullable JavaType type) {
JavaType.Method sam = null;
JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(type);
if (fullyQualified == null) {
return null;
}
for (JavaType.Method method : fullyQualified.getMethods()) {
if (method.hasFlags(Flag.Default) || method.hasFlags(Flag.Static)) {
continue;
}
if (sam != null) {
return null;
}
sam = method;
}
return sam;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,24 @@ public List<String> get() {
);
}

@Test
void nonSamParameter() {
rewriteRun(
//language=java
java(
"""
import java.util.*;

class Test {
public boolean foo() {
return Objects.equals("x", (Comparable<String>) (s) -> 1);
}
}
"""
)
);
}

@Issue("https://github.com/openrewrite/rewrite/issues/1647")
@Test
void redundantTypeCast() {
Expand Down Expand Up @@ -374,8 +392,7 @@ class ExtendTest extends Test {


@Test
@Issue("https://github.com/moderneinc/support-app/issues/17")
void test() {
void lambdaWithComplexTypeInference() {
rewriteRun(
java(
"""
Expand All @@ -390,29 +407,8 @@ void method() {
(Supplier<Map<String, Integer>>) () -> {
Map<String, Integer> choices = Map.of("id1", 2);
return choices.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
});
}
}

class MapDropdownChoice<K, V> {
public MapDropdownChoice(Supplier<? extends Map<K, ? extends V>> choiceMap) {
}
}
""",
"""
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors;

class Test {
void method() {
Object o2 = new MapDropdownChoice<String, Integer>(
() -> {
Map<String, Integer> choices = Map.of("id1", 2);
return choices.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue,
(e1, e2) -> e1, LinkedHashMap::new));
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -664,4 +664,75 @@ public void run() {
)
);
}

@Test
@Issue("https://github.com/moderneinc/support-app/issues/17")
void lambdaWithComplexTypeInference() {
rewriteRun(
java(
"""
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors;

class Test {
void method() {
Object o = new MapDropdownChoice<String, Integer>(
new Supplier<Map<String, Integer>>() {
@Override
public Map<String, Integer> getObject() {
Map<String, Integer> choices = Map.of("id1", 1);
return choices.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (e1, e2) -> e1, LinkedHashMap::new));
}
});
Object o2 = new MapDropdownChoice<String, Integer>(
new Supplier<Map<String, Integer>>() {
@Override
public Map<String, Integer> getObject() {
Map<String, Integer> choices = Map.of("id1", 2);
return choices.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}
});
}
}

class MapDropdownChoice<K, V> {
public MapDropdownChoice(Supplier<? extends Map<K, ? extends V>> choiceMap) {
}
}
""",
"""
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors;

class Test {
void method() {
Object o = new MapDropdownChoice<String, Integer>(
(Supplier<Map<String, Integer>>) () -> {
Map<String, Integer> choices = Map.of("id1", 1);
return choices.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (e1, e2) -> e1, LinkedHashMap::new));
});
Object o2 = new MapDropdownChoice<String, Integer>(
(Supplier<Map<String, Integer>>) () -> {
Map<String, Integer> choices = Map.of("id1", 2);
return choices.entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
});
}
}

class MapDropdownChoice<K, V> {
public MapDropdownChoice(Supplier<? extends Map<K, ? extends V>> choiceMap) {
}
}
"""
)
);
}
}