diff --git a/nullaway/src/main/java/com/uber/nullaway/handlers/AssertionHandler.java b/nullaway/src/main/java/com/uber/nullaway/handlers/AssertionHandler.java index d886a9cab8..61a5d9fc83 100644 --- a/nullaway/src/main/java/com/uber/nullaway/handlers/AssertionHandler.java +++ b/nullaway/src/main/java/com/uber/nullaway/handlers/AssertionHandler.java @@ -60,7 +60,8 @@ public NullnessHint onDataflowVisitMethodInvocation( // Look for statements of the form: assertThat(A).isNotNull() or // assertThat(A).isInstanceOf(Foo.class) // A will not be NULL after this statement. - if (methodNameUtil.isMethodIsNotNull(callee) || methodNameUtil.isMethodIsInstanceOf(callee)) { + if (methodNameUtil.isMethodIsNotNull(callee, state) + || methodNameUtil.isMethodIsInstanceOf(callee, state)) { AccessPath ap = getAccessPathForNotNullAssertThatExpr(node, state, apContext); if (ap != null) { bothUpdates.set(ap, NONNULL); diff --git a/nullaway/src/main/java/com/uber/nullaway/handlers/MethodNameUtil.java b/nullaway/src/main/java/com/uber/nullaway/handlers/MethodNameUtil.java index 1a276bb428..88fc564773 100644 --- a/nullaway/src/main/java/com/uber/nullaway/handlers/MethodNameUtil.java +++ b/nullaway/src/main/java/com/uber/nullaway/handlers/MethodNameUtil.java @@ -22,8 +22,12 @@ * THE SOFTWARE. */ +import com.google.errorprone.VisitorState; +import com.google.errorprone.suppliers.Supplier; +import com.google.errorprone.suppliers.Suppliers; import com.google.errorprone.util.ASTHelpers; import com.sun.tools.javac.code.Symbol; +import com.sun.tools.javac.code.Type; import com.sun.tools.javac.util.Name; import com.uber.nullaway.annotations.Initializer; import org.checkerframework.nullaway.dataflow.cfg.node.MethodInvocationNode; @@ -41,8 +45,6 @@ class MethodNameUtil { // assertions in this handler. private static final String IS_NOT_NULL_METHOD = "isNotNull"; private static final String IS_OWNER_TRUTH_SUBJECT = "com.google.common.truth.Subject"; - private static final String IS_OWNER_ASSERTJ_ABSTRACT_ASSERT = - "org.assertj.core.api.AbstractAssert"; private static final String IS_INSTANCE_OF_METHOD = "isInstanceOf"; private static final String IS_INSTANCE_OF_ANY_METHOD = "isInstanceOfAny"; private static final String IS_TRUE_METHOD = "isTrue"; @@ -79,6 +81,9 @@ class MethodNameUtil { private static final String NULL_VALUE_MATCHER = "nullValue"; private static final String INSTANCE_OF_MATCHER = "instanceOf"; + private static final Supplier ASSERTJ_ASSERT_TYPE_SUPPLIER = + Suppliers.typeFromString("org.assertj.core.api.Assert"); + // Names of the methods (and their owners) used to identify assertions in this handler. Name used // here refers to com.sun.tools.javac.util.Name. Comparing methods using Names is faster than // comparing using strings. @@ -87,7 +92,6 @@ class MethodNameUtil { private Name isInstanceOf; private Name isInstanceOfAny; private Name isOwnerTruthSubject; - private Name isOwnerAssertJAbstractAssert; private Name isTrue; private Name isFalse; @@ -130,7 +134,6 @@ class MethodNameUtil { void initializeMethodNames(Name.Table table) { isNotNull = table.fromString(IS_NOT_NULL_METHOD); isOwnerTruthSubject = table.fromString(IS_OWNER_TRUTH_SUBJECT); - isOwnerAssertJAbstractAssert = table.fromString(IS_OWNER_ASSERTJ_ABSTRACT_ASSERT); isInstanceOf = table.fromString(IS_INSTANCE_OF_METHOD); isInstanceOfAny = table.fromString(IS_INSTANCE_OF_ANY_METHOD); @@ -172,16 +175,16 @@ void initializeMethodNames(Name.Table table) { instanceOfMatcher = table.fromString(INSTANCE_OF_MATCHER); } - boolean isMethodIsNotNull(Symbol.MethodSymbol methodSymbol) { + boolean isMethodIsNotNull(Symbol.MethodSymbol methodSymbol, VisitorState state) { return matchesMethod(methodSymbol, isNotNull, isOwnerTruthSubject) - || matchesMethod(methodSymbol, isNotNull, isOwnerAssertJAbstractAssert); + || matchesAssertJAssertMethod(methodSymbol, isNotNull, state); } - boolean isMethodIsInstanceOf(Symbol.MethodSymbol methodSymbol) { + boolean isMethodIsInstanceOf(Symbol.MethodSymbol methodSymbol, VisitorState state) { return matchesMethod(methodSymbol, isInstanceOf, isOwnerTruthSubject) - || matchesMethod(methodSymbol, isInstanceOf, isOwnerAssertJAbstractAssert) + || matchesAssertJAssertMethod(methodSymbol, isInstanceOf, state) // Truth doesn't seem to have isInstanceOfAny - || matchesMethod(methodSymbol, isInstanceOfAny, isOwnerAssertJAbstractAssert); + || matchesAssertJAssertMethod(methodSymbol, isInstanceOfAny, state); } boolean isMethodAssertTrue(Symbol.MethodSymbol methodSymbol) { @@ -303,6 +306,24 @@ private boolean matchesMethod( && methodSymbol.owner.getQualifiedName().equals(toMatchOwnerName); } + /** + * Checks if the method is an AssertJ assert method, i.e., it has the same name as + * toMatchMethodName and its owner is a subtype of AssertJ's Assert class. + * + * @param methodSymbol the method symbol to check + * @param toMatchMethodName the method name to match + * @param state the visitor state + * @return {@code true} if the method matches, {@code false} otherwise + */ + private boolean matchesAssertJAssertMethod( + Symbol.MethodSymbol methodSymbol, Name toMatchMethodName, VisitorState state) { + if (!methodSymbol.name.equals(toMatchMethodName)) { + return false; + } + return ASTHelpers.isSubtype( + methodSymbol.owner.type, ASSERTJ_ASSERT_TYPE_SUPPLIER.get(state), state); + } + boolean isUtilInitialized() { return isNotNull != null; } diff --git a/nullaway/src/test/java/com/uber/nullaway/AssertionLibsTests.java b/nullaway/src/test/java/com/uber/nullaway/AssertionLibsTests.java index bd9cc1231a..b46c92293f 100644 --- a/nullaway/src/test/java/com/uber/nullaway/AssertionLibsTests.java +++ b/nullaway/src/test/java/com/uber/nullaway/AssertionLibsTests.java @@ -526,4 +526,27 @@ public void doNotSupportAssertJAssertThatWhenDisabled() { "}") .doTest(); } + + @Test + public void collectionAssertIsNotNull() { + makeTestHelperWithArgs( + Arrays.asList( + "-d", + temporaryFolder.getRoot().getAbsolutePath(), + "-XepOpt:NullAway:AnnotatedPackages=com.uber", + "-XepOpt:NullAway:HandleTestAssertionLibraries=true")) + .addSourceLines( + "Test.java", + "import org.jspecify.annotations.*;", + "import java.util.Collection;", + "import static org.assertj.core.api.Assertions.assertThat;", + "@NullMarked", + "class Test {", + " void test(@Nullable Collection c) {", + " assertThat(c).isNotNull();", + " c.size();", + " }", + "}") + .doTest(); + } }