diff --git a/src/main/java/org/openrewrite/staticanalysis/AnnotateNullableMethods.java b/src/main/java/org/openrewrite/staticanalysis/AnnotateNullableMethods.java new file mode 100644 index 0000000000..94763084af --- /dev/null +++ b/src/main/java/org/openrewrite/staticanalysis/AnnotateNullableMethods.java @@ -0,0 +1,162 @@ +/* + * Copyright 2024 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.staticanalysis; + +import org.jspecify.annotations.Nullable; +import org.openrewrite.Cursor; +import org.openrewrite.ExecutionContext; +import org.openrewrite.Recipe; +import org.openrewrite.TreeVisitor; +import org.openrewrite.java.*; +import org.openrewrite.java.service.AnnotationService; +import org.openrewrite.java.tree.Expression; +import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.JavaType; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +public class AnnotateNullableMethods extends Recipe { + + private static final String NULLABLE_ANN_CLASS = "org.jspecify.annotations.Nullable"; + private static final AnnotationMatcher NULLABLE_ANNOTATION_MATCHER = new AnnotationMatcher("@" + NULLABLE_ANN_CLASS); + + @Override + public String getDisplayName() { + return "Annotate methods which may return `null` with `@Nullable`"; + } + + @Override + public String getDescription() { + return "Add the `@org.jspecify.annotation.Nullable` to non-private methods that may return `null`. " + + "This recipe scans for methods that do not already have a `@Nullable` annotation and checks their return " + + "statements for potential null values. It also identifies known methods from standard libraries that may " + + "return null, such as methods from `Map`, `Queue`, `Deque`, `NavigableSet`, and `Spliterator`. " + + "The return of streams, or lambdas are not taken into account."; + } + + @Override + public TreeVisitor getVisitor() { + return new JavaIsoVisitor() { + @Override + public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration methodDeclaration, ExecutionContext ctx) { + if (!methodDeclaration.hasModifier(J.Modifier.Type.Public) || + methodDeclaration.getMethodType() == null || + methodDeclaration.getMethodType().getReturnType() instanceof JavaType.Primitive || + service(AnnotationService.class).matches(getCursor(), NULLABLE_ANNOTATION_MATCHER) || + (methodDeclaration.getReturnTypeExpression() != null && + service(AnnotationService.class).matches(new Cursor(null, methodDeclaration.getReturnTypeExpression()), NULLABLE_ANNOTATION_MATCHER))) { + return methodDeclaration; + } + + J.MethodDeclaration md = super.visitMethodDeclaration(methodDeclaration, ctx); + updateCursor(md); + if (FindNullableReturnStatements.find(md.getBody(), getCursor().getParentTreeCursor())) { + J.MethodDeclaration annotatedMethod = JavaTemplate.builder("@" + NULLABLE_ANN_CLASS) + .javaParser(JavaParser.fromJavaVersion().dependsOn( + "package org.jspecify.annotations;public @interface Nullable {}")) + .build() + .apply(getCursor(), md.getCoordinates().addAnnotation(Comparator.comparing(J.Annotation::getSimpleName))); + doAfterVisit(ShortenFullyQualifiedTypeReferences.modifyOnly(annotatedMethod)); + return (J.MethodDeclaration) new NullableOnMethodReturnType().getVisitor().visitNonNull(annotatedMethod, ctx, getCursor().getParentTreeCursor()); + } + return md; + } + }; + } + + private static class FindNullableReturnStatements extends JavaIsoVisitor { + + private static final List KNOWN_NULLABLE_METHODS = Arrays.asList( + new MethodMatcher("java.util.Map computeIfAbsent(..)"), + new MethodMatcher("java.util.Map computeIfPresent(..)"), + new MethodMatcher("java.util.Map get(..)"), + new MethodMatcher("java.util.Map merge(..)"), + new MethodMatcher("java.util.Map put(..)"), + new MethodMatcher("java.util.Map putIfAbsent(..)"), + + new MethodMatcher("java.util.Queue poll(..)"), + new MethodMatcher("java.util.Queue peek(..)"), + + new MethodMatcher("java.util.Deque peekFirst(..)"), + new MethodMatcher("java.util.Deque pollFirst(..)"), + new MethodMatcher("java.util.Deque peekLast(..)"), + + new MethodMatcher("java.util.NavigableSet lower(..)"), + new MethodMatcher("java.util.NavigableSet floor(..)"), + new MethodMatcher("java.util.NavigableSet ceiling(..)"), + new MethodMatcher("java.util.NavigableSet higher(..)"), + new MethodMatcher("java.util.NavigableSet pollFirst(..)"), + new MethodMatcher("java.util.NavigableSet pollLast(..)"), + + new MethodMatcher("java.util.NavigableMap lowerEntry(..)"), + new MethodMatcher("java.util.NavigableMap floorEntry(..)"), + new MethodMatcher("java.util.NavigableMap ceilingEntry(..)"), + new MethodMatcher("java.util.NavigableMap higherEntry(..)"), + new MethodMatcher("java.util.NavigableMap lowerKey(..)"), + new MethodMatcher("java.util.NavigableMap floorKey(..)"), + new MethodMatcher("java.util.NavigableMap ceilingKey(..)"), + new MethodMatcher("java.util.NavigableMap higherKey(..)"), + new MethodMatcher("java.util.NavigableMap firstEntry(..)"), + new MethodMatcher("java.util.NavigableMap lastEntry(..)"), + new MethodMatcher("java.util.NavigableMap pollFirstEntry(..)"), + new MethodMatcher("java.util.NavigableMap pollLastEntry(..)"), + + new MethodMatcher("java.util.Spliterator trySplit(..)") + ); + + static boolean find(@Nullable J subtree, Cursor parentTreeCursor) { + return new FindNullableReturnStatements().reduce(subtree, new AtomicBoolean(), parentTreeCursor).get(); + } + + @Override + public J.Lambda visitLambda(J.Lambda lambda, AtomicBoolean atomicBoolean) { + // Do not evaluate return statements in lambdas + return lambda; + } + + @Override + public J.Return visitReturn(J.Return retrn, AtomicBoolean found) { + if (found.get()) { + return retrn; + } + J.Return r = super.visitReturn(retrn, found); + found.set(maybeIsNull(r.getExpression())); + return r; + } + + private boolean maybeIsNull(@Nullable Expression returnExpression) { + if (returnExpression instanceof J.Literal) { + return ((J.Literal) returnExpression).getValue() == null; + } + if (returnExpression instanceof J.MethodInvocation) { + return isKnowNullableMethod((J.MethodInvocation) returnExpression); + } + return false; + } + + private boolean isKnowNullableMethod(J.MethodInvocation methodInvocation) { + for (MethodMatcher m : KNOWN_NULLABLE_METHODS) { + if (m.matches(methodInvocation)) { + return true; + } + } + return false; + } + } +} diff --git a/src/test/java/org/openrewrite/staticanalysis/AnnotateNullableMethodsTest.java b/src/test/java/org/openrewrite/staticanalysis/AnnotateNullableMethodsTest.java new file mode 100644 index 0000000000..59d3848fbd --- /dev/null +++ b/src/test/java/org/openrewrite/staticanalysis/AnnotateNullableMethodsTest.java @@ -0,0 +1,252 @@ +/* + * Copyright 2024 the original author or authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + *

+ * https://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.openrewrite.staticanalysis; + +import org.junit.jupiter.api.Test; +import org.openrewrite.DocumentExample; +import org.openrewrite.java.JavaParser; +import org.openrewrite.test.RecipeSpec; +import org.openrewrite.test.RewriteTest; + +import static org.openrewrite.java.Assertions.java; + +class AnnotateNullableMethodsTest implements RewriteTest { + + @Override + public void defaults(RecipeSpec spec) { + spec + .recipe(new AnnotateNullableMethods()) + .parser(JavaParser.fromJavaVersion().classpath("jspecify")); + } + + @DocumentExample + @Test + void methodReturnsNullLiteral() { + rewriteRun( + //language=java + java( + """ + public class Test { + + public String getString() { + return null; + } + + public String getStringWithMultipleReturn() { + if (System.currentTimeMillis() % 2 == 0) { + return "Not null"; + } + return null; + } + } + """, + """ + import org.jspecify.annotations.Nullable; + + public class Test { + + public @Nullable String getString() { + return null; + } + + public @Nullable String getStringWithMultipleReturn() { + if (System.currentTimeMillis() % 2 == 0) { + return "Not null"; + } + return null; + } + } + """ + ) + ); + } + + @Test + void methodReturnNullButIsAlreadyAnnotated() { + rewriteRun( + //language=java + java( + """ + import org.jspecify.annotations.Nullable; + + public class Test { + public @Nullable String getString() { + return null; + } + + public @Nullable String getStringWithMultipleReturn() { + if (System.currentTimeMillis() % 2 == 0) { + return "Not null"; + } + return null; + } + } + """ + ) + ); + } + + @Test + void methodDoesNotReturnNull() { + rewriteRun( + //language=java + java( + """ + package org.example; + + public class Test { + public String getString() { + return "Hello"; + } + } + """ + ) + ); + } + + @Test + void methodReturnsDelegateKnowNullableMethod() { + rewriteRun( + //language=java + java( + """ + import java.util.Map; + + public class Test { + + public String getString(Map map) { + return map.get("key"); + } + } + """, + """ + import org.jspecify.annotations.Nullable; + + import java.util.Map; + + public class Test { + + public @Nullable String getString(Map map) { + return map.get("key"); + } + } + """ + ) + ); + } + + @Test + void methodWithLambdaShouldNotBeAnnotated() { + rewriteRun( + //language=java + java( + """ + import java.util.stream.Stream; + class A { + public Runnable getRunnable() { + return () -> null; + } + + public Integer someStream(){ + // Stream with lambda class. + return Stream.of(1, 2, 3) + .map(i -> {if (i == 2) return null; else return i;}) + .reduce((a, b) -> a + b) + .orElse(null); + } + } + """ + ) + ); + } + + @Test + void privateMethodsShouldNotBeAnnotated() { + rewriteRun( + //language=java + java( + """ + public class Test { + private String getString() { + return null; + } + } + """ + ) + ); + } + + @Test + void returnWithinNewClass() { + rewriteRun( + //language=java + java( + """ + import java.util.concurrent.Callable; + + public class Test { + + public Callable getString() { + return new Callable() { + @Override + public String call() throws Exception { + return null; + } + }; + } + + } + """, + """ + import org.jspecify.annotations.Nullable; + + import java.util.concurrent.Callable; + + public class Test { + + public Callable getString() { + return new Callable() { + + @Override + public @Nullable String call() throws Exception { + return null; + } + }; + } + + } + """ + ) + ); + } + + @Test + void returnStaticNestInnerClassAnnotation() { + rewriteRun( + //language=java + java( + """ + import org.jspecify.annotations.Nullable; + + public class Outer { + public static Outer.@Nullable Inner test() { return null; } + static class Inner {} + } + """ + ) + ); + } +}