Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -120,18 +120,13 @@ static J.MemberReference newInstanceMethodReference(Expression containing, Strin
);
}

static J.@Nullable FieldAccess newClassLiteral(@Nullable JavaType type, boolean qualified) {
JavaType.Class classType = getClassType(type);
if (classType == null) {
return null;
}

JavaType.Parameterized parameterized = new JavaType.Parameterized(null, classType, singletonList(type));
static J.FieldAccess newClassLiteral(JavaType.Class classType, JavaType originalType, J instanceOfClass) {
JavaType.Parameterized parameterized = new JavaType.Parameterized(null, classType, singletonList(originalType));
return new J.FieldAccess(
randomId(),
Space.EMPTY,
Markers.EMPTY,
className(type, qualified),
instanceOfClass.withPrefix(Space.EMPTY), // Use the original expression directly
new JLeftPadded<>(
Space.EMPTY,
new J.Identifier(randomId(), Space.EMPTY, Markers.EMPTY, emptyList(), "class", parameterized, null),
Expand All @@ -141,7 +136,7 @@ static J.MemberReference newInstanceMethodReference(Expression containing, Strin
);
}

private static JavaType.@Nullable Class getClassType(@Nullable JavaType type) {
static JavaType.@Nullable Class getClassType(@Nullable JavaType type) {
if (type instanceof JavaType.Class) {
JavaType.Class classType = (JavaType.Class) type;
if ("java.lang.Class".equals(classType.getFullyQualifiedName())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,11 @@ public J visitLambda(J.Lambda lambda, ExecutionContext ctx) {
J j = instanceOf.getClazz();
if ((j instanceof J.Identifier || j instanceof J.FieldAccess) &&
instanceOf.getExpression() instanceof J.Identifier) {
J.FieldAccess classLiteral = newClassLiteral(((TypeTree) j).getType(), j instanceof J.FieldAccess);
if (classLiteral != null) {
// Create the class literal directly from the original expression
JavaType originalType = ((TypeTree) j).getType();
JavaType.Class classType = getClassType(originalType);
if (classType != null) {
J.FieldAccess classLiteral = newClassLiteral(classType, originalType, j);
//noinspection DataFlowIssue
JavaType.FullyQualified rawClassType = ((JavaType.Parameterized) classLiteral.getType()).getType();
Optional<JavaType.Method> isInstanceMethod = rawClassType.getMethods().stream().filter(m -> "isInstance".equals(m.getName())).findFirst();
Expand All @@ -123,11 +126,13 @@ public J visitLambda(J.Lambda lambda, ExecutionContext ctx) {
J tree = j.getTree();
if ((tree instanceof J.Identifier || tree instanceof J.FieldAccess) &&
!(j.getType() instanceof JavaType.GenericTypeVariable)) {
J.FieldAccess classLiteral = newClassLiteral(((Expression) tree).getType(), tree instanceof J.FieldAccess);
if (classLiteral != null) {
// Create the class literal directly from the original expression
JavaType.Class classType = getClassType(((Expression) tree).getType());
if (classType != null) {
J.FieldAccess classLiteral = newClassLiteral(classType, ((Expression) tree).getType(), tree);
//noinspection DataFlowIssue
JavaType.FullyQualified classType = ((JavaType.Parameterized) classLiteral.getType()).getType();
Optional<JavaType.Method> castMethod = classType.getMethods().stream().filter(m -> "cast".equals(m.getName())).findFirst();
JavaType.FullyQualified fullClassType = ((JavaType.Parameterized) classLiteral.getType()).getType();
Optional<JavaType.Method> castMethod = fullClassType.getMethods().stream().filter(m -> "cast".equals(m.getName())).findFirst();
if (castMethod.isPresent()) {
J.MemberReference updated = newInstanceMethodReference(classLiteral, castMethod.get(), lambda.getType()).withPrefix(lambda.getPrefix());
doAfterVisit(service(ImportService.class).shortenFullyQualifiedTypeReferencesIn(updated));
Expand Down Expand Up @@ -252,9 +257,7 @@ private boolean hasSelectWhoseReferenceMightChange(MethodCall method) {
JavaType.Variable fieldType = ((J.FieldAccess) select).getName().getFieldType();
return fieldType != null && fieldType.getOwner() instanceof JavaType.Class && !fieldType.hasFlags(Flag.Final);
}
if (select instanceof J.NewClass || select instanceof J.Parentheses) {
return true;
}
return select instanceof J.NewClass || select instanceof J.Parentheses;
}
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1487,4 +1487,93 @@ public void run() {
)
);
}

@Issue("https://github.com/openrewrite/rewrite-static-analysis/issues/744")
@Test
void methodReferenceNestedClassImportInvalid() {
rewriteRun(
//language=java
java(
"""
import java.util.List;
import java.util.Map.Entry;

public class Foo {
private void foo() {
List.of().stream().filter(i -> i instanceof Entry);
}
}
""",
"""
import java.util.List;
import java.util.Map.Entry;

public class Foo {
private void foo() {
List.of().stream().filter(Entry.class::isInstance);
}
}
"""
)
);
}

@Issue("https://github.com/openrewrite/rewrite-static-analysis/issues/744")
@Test
void methodReferenceNestedClassFullyQualified() {
rewriteRun(
//language=java
java(
"""
import java.util.List;

public class Foo {
private void foo() {
List.of().stream().filter(i -> i instanceof java.util.Map.Entry);
}
}
""",
"""
import java.util.List;
import java.util.Map;

public class Foo {
private void foo() {
List.of().stream().filter(Map.Entry.class::isInstance);
}
}
"""
)
);
}

@Issue("https://github.com/openrewrite/rewrite-static-analysis/issues/744")
@Test
void methodReferenceNestedClassCast() {
rewriteRun(
//language=java
java(
"""
import java.util.List;
import java.util.Map.Entry;

public class Foo {
private void foo() {
List.of().stream().map(i -> (Entry) i);
}
}
""",
"""
import java.util.List;
import java.util.Map.Entry;

public class Foo {
private void foo() {
List.of().stream().map(Entry.class::cast);
}
}
"""
)
);
}
}