diff --git a/src/main/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCast.java b/src/main/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCast.java index 6e6c81444a..ab73be5071 100644 --- a/src/main/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCast.java +++ b/src/main/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCast.java @@ -101,7 +101,9 @@ public J visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) { JavaType expressionType = visitedTypeCast.getExpression().getType(); JavaType castType = visitedTypeCast.getType(); - if (targetType == null || targetType instanceof JavaType.Primitive && castType != expressionType) { + if (targetType == null || + targetType instanceof JavaType.Primitive && castType != expressionType || + typeCast.getExpression() instanceof J.Lambda && castType instanceof JavaType.Parameterized) { // Not currently supported, this will be more accurate with dataflow analysis. return visitedTypeCast; } else if (!(targetType instanceof JavaType.Array) && TypeUtils.isOfClassType(targetType, "java.lang.Object") || diff --git a/src/main/java/org/openrewrite/staticanalysis/UseLambdaForFunctionalInterface.java b/src/main/java/org/openrewrite/staticanalysis/UseLambdaForFunctionalInterface.java index 6e65b2d895..843f883b03 100644 --- a/src/main/java/org/openrewrite/staticanalysis/UseLambdaForFunctionalInterface.java +++ b/src/main/java/org/openrewrite/staticanalysis/UseLambdaForFunctionalInterface.java @@ -132,30 +132,11 @@ public J visitNewClass(J.NewClass newClass, ExecutionContext ctx) { return n; } - @Nullable - private JavaType.Method getSamCompatible(JavaType type) { - JavaType.Method sam = null; - JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(type); - if (fullyQualified == null) { - return null; - } - for (JavaType.Method method : fullyQualified.getMethods()) { - if (method.hasFlags(Flag.Default) || method.hasFlags(Flag.Static)) { - continue; - } - if (sam != null) { - return null; - } - sam = method; - } - return sam; - } - private J maybeAddCast(J.Lambda lambda, J.NewClass original) { J parent = getCursor().getParentTreeCursor().getValue(); - if (parent instanceof J.MethodInvocation) { - J.MethodInvocation method = (J.MethodInvocation) parent; + if (parent instanceof MethodCall) { + MethodCall method = (MethodCall) parent; List arguments = method.getArguments(); for (int i = 0; i < arguments.size(); i++) { Expression argument = arguments.get(i); @@ -180,7 +161,7 @@ private J maybeAddCast(J.Lambda lambda, J.NewClass original) { return lambda; } - private boolean methodArgumentRequiresCast(J.Lambda lambda, J.MethodInvocation method, int argumentIndex) { + private boolean methodArgumentRequiresCast(J.Lambda lambda, MethodCall method, int argumentIndex) { JavaType.FullyQualified lambdaType = TypeUtils.asFullyQualified(lambda.getType()); if (lambdaType == null) { return false; @@ -207,7 +188,11 @@ private boolean methodArgumentRequiresCast(J.Lambda lambda, J.MethodInvocation m } } } - return count >= 2; + if (count >= 2) { + return true; + } + + return hasGenerics(lambda); } private boolean areMethodsAmbiguous(@Nullable JavaType.Method m1, @Nullable JavaType.Method m2) { @@ -419,4 +404,42 @@ public J visitVariable(J.VariableDeclarations.NamedVariable variable, Integer in return hasShadow.get(); } + + private static boolean hasGenerics(J.Lambda lambda) { + AtomicBoolean atomicBoolean = new AtomicBoolean(); + new JavaVisitor() { + @Override + public J visitMethodInvocation(J.MethodInvocation method, AtomicBoolean atomicBoolean) { + if (method.getMethodType() != null && + method.getMethodType().getParameterTypes().stream() + .anyMatch(p -> p instanceof JavaType.Parameterized && + ((JavaType.Parameterized) p).getTypeParameters().stream().anyMatch(t -> t instanceof JavaType.GenericTypeVariable)) + ) { + atomicBoolean.set(true); + } + return super.visitMethodInvocation(method, atomicBoolean); + } + }.visit(lambda.getBody(), atomicBoolean); + return atomicBoolean.get(); + } + + // TODO consider moving to TypeUtils + @Nullable + private static JavaType.Method getSamCompatible(@Nullable JavaType type) { + JavaType.Method sam = null; + JavaType.FullyQualified fullyQualified = TypeUtils.asFullyQualified(type); + if (fullyQualified == null) { + return null; + } + for (JavaType.Method method : fullyQualified.getMethods()) { + if (method.hasFlags(Flag.Default) || method.hasFlags(Flag.Static)) { + continue; + } + if (sam != null) { + return null; + } + sam = method; + } + return sam; + } } diff --git a/src/test/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCastTest.java b/src/test/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCastTest.java index 8ec941d3c6..e2466561f5 100644 --- a/src/test/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCastTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/RemoveRedundantTypeCastTest.java @@ -154,6 +154,24 @@ public List get() { ); } + @Test + void nonSamParameter() { + rewriteRun( + //language=java + java( + """ + import java.util.*; + + class Test { + public boolean foo() { + return Objects.equals("x", (Comparable) (s) -> 1); + } + } + """ + ) + ); + } + @Issue("https://github.com/openrewrite/rewrite/issues/1647") @Test void redundantTypeCast() { @@ -374,8 +392,7 @@ class ExtendTest extends Test { @Test - @Issue("https://github.com/moderneinc/support-app/issues/17") - void test() { + void lambdaWithComplexTypeInference() { rewriteRun( java( """ @@ -390,29 +407,8 @@ void method() { (Supplier>) () -> { Map choices = Map.of("id1", 2); return choices.entrySet().stream() - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - }); - } - } - - class MapDropdownChoice { - public MapDropdownChoice(Supplier> choiceMap) { - } - } - """, - """ - import java.util.LinkedHashMap; - import java.util.Map; - import java.util.function.Supplier; - import java.util.stream.Collectors; - - class Test { - void method() { - Object o2 = new MapDropdownChoice( - () -> { - Map choices = Map.of("id1", 2); - return choices.entrySet().stream() - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, + (e1, e2) -> e1, LinkedHashMap::new)); }); } } diff --git a/src/test/java/org/openrewrite/staticanalysis/UseLambdaForFunctionalInterfaceTest.java b/src/test/java/org/openrewrite/staticanalysis/UseLambdaForFunctionalInterfaceTest.java index a74288059c..991bca6da6 100644 --- a/src/test/java/org/openrewrite/staticanalysis/UseLambdaForFunctionalInterfaceTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/UseLambdaForFunctionalInterfaceTest.java @@ -664,4 +664,75 @@ public void run() { ) ); } + + @Test + @Issue("https://github.com/moderneinc/support-app/issues/17") + void lambdaWithComplexTypeInference() { + rewriteRun( + java( + """ + import java.util.LinkedHashMap; + import java.util.Map; + import java.util.function.Supplier; + import java.util.stream.Collectors; + + class Test { + void method() { + Object o = new MapDropdownChoice( + new Supplier>() { + @Override + public Map getObject() { + Map choices = Map.of("id1", 1); + return choices.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (e1, e2) -> e1, LinkedHashMap::new)); + } + }); + Object o2 = new MapDropdownChoice( + new Supplier>() { + @Override + public Map getObject() { + Map choices = Map.of("id1", 2); + return choices.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + }); + } + } + + class MapDropdownChoice { + public MapDropdownChoice(Supplier> choiceMap) { + } + } + """, + """ + import java.util.LinkedHashMap; + import java.util.Map; + import java.util.function.Supplier; + import java.util.stream.Collectors; + + class Test { + void method() { + Object o = new MapDropdownChoice( + (Supplier>) () -> { + Map choices = Map.of("id1", 1); + return choices.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (e1, e2) -> e1, LinkedHashMap::new)); + }); + Object o2 = new MapDropdownChoice( + (Supplier>) () -> { + Map choices = Map.of("id1", 2); + return choices.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + }); + } + } + + class MapDropdownChoice { + public MapDropdownChoice(Supplier> choiceMap) { + } + } + """ + ) + ); + } }