diff --git a/nullaway/src/main/java/com/uber/nullaway/NullabilityUtil.java b/nullaway/src/main/java/com/uber/nullaway/NullabilityUtil.java index 42b52f92d7..9a10e9a6c7 100644 --- a/nullaway/src/main/java/com/uber/nullaway/NullabilityUtil.java +++ b/nullaway/src/main/java/com/uber/nullaway/NullabilityUtil.java @@ -730,4 +730,31 @@ public static ExpressionTree stripParensAndCasts(ExpressionTree expr) { } return expr; } + + public record ExprTreeAndState(ExpressionTree expr, VisitorState state) {} + + /** + * strip out enclosing parentheses, and update the tree path in the VisitorState to point to the + * stripped expression if the original expression was the leaf of the path + * + * @param expr a potentially parenthesised expression. + * @param state the VisitorState + * @return the same expression without parentheses, and the updated VisitorState + */ + public static ExprTreeAndState stripParensAndUpdateTreePath( + ExpressionTree expr, VisitorState state) { + TreePath path = state.getPath(); + if (path.getLeaf() != expr) { + // if the expression is not the leaf of the path, we can't update the path to point to the + // stripped expression, so we just return the original expression and state + return new ExprTreeAndState(expr, state); + } + ExpressionTree resultExpr = expr; + while (resultExpr instanceof ParenthesizedTree) { + resultExpr = ((ParenthesizedTree) resultExpr).getExpression(); + path = new TreePath(path, resultExpr); + } + VisitorState resultState = path == state.getPath() ? state : state.withPath(path); + return new ExprTreeAndState(resultExpr, resultState); + } } diff --git a/nullaway/src/main/java/com/uber/nullaway/generics/GenericsChecks.java b/nullaway/src/main/java/com/uber/nullaway/generics/GenericsChecks.java index 81066958da..95f30d1a61 100644 --- a/nullaway/src/main/java/com/uber/nullaway/generics/GenericsChecks.java +++ b/nullaway/src/main/java/com/uber/nullaway/generics/GenericsChecks.java @@ -36,6 +36,7 @@ import com.sun.tools.javac.code.Type; import com.sun.tools.javac.code.Types; import com.sun.tools.javac.tree.JCTree; +import com.sun.tools.javac.tree.TreeInfo; import com.sun.tools.javac.util.Name; import com.sun.tools.javac.util.Names; import com.uber.nullaway.CodeAnnotationInfo; @@ -423,25 +424,39 @@ private void reportInvalidOverridingMethodParamTypeError( * @return Type of the tree with preserved annotations. */ private @Nullable Type getTreeType(Tree tree, VisitorState state) { - tree = ASTHelpers.stripParentheses(tree); + if (tree instanceof ExpressionTree exprTree) { + NullabilityUtil.ExprTreeAndState exprTreeAndState = + NullabilityUtil.stripParensAndUpdateTreePath(exprTree, state); + tree = exprTreeAndState.expr(); + state = exprTreeAndState.state(); + } if (tree instanceof LambdaExpressionTree || tree instanceof MemberReferenceTree) { Type result = inferredPolyExpressionTypes.get(tree); if (result == null) { result = ASTHelpers.getType(tree); } - if (result != null && result.isRaw()) { - return null; + return typeOrNullIfRaw(result); + } + if (tree instanceof NewClassTree newClassTree) { + if (TreeInfo.isDiamond((JCTree) newClassTree)) { + if (newClassTree.getClassBody() != null) { + // Keep existing behavior for diamond anonymous classes, which are not yet fully + // supported. Tracked in https://github.com/uber/NullAway/issues/1475 + return null; + } + // For constructor calls using diamond operator, infer from assignment context. + // TODO handle diamond constructor calls passed to generic methods + // https://github.com/uber/NullAway/issues/1470 + Type fromAssignmentContext = getDiamondTypeFromContext(newClassTree, state); + if (fromAssignmentContext != null) { + return fromAssignmentContext; + } } - return result; - } - if (tree instanceof NewClassTree - && ((NewClassTree) tree).getIdentifier() instanceof ParameterizedTypeTree paramTypedTree) { - if (paramTypedTree.getTypeArguments().isEmpty()) { - // diamond operator, which we do not yet support; for now, return null - // TODO: support diamond operators - return null; + if (newClassTree.getIdentifier() instanceof ParameterizedTypeTree paramTypedTree + && !paramTypedTree.getTypeArguments().isEmpty()) { + return typeWithPreservedAnnotations(paramTypedTree); } - return typeWithPreservedAnnotations(paramTypedTree); + return typeOrNullIfRaw(ASTHelpers.getType(tree)); } else if (tree instanceof NewArrayTree && ((NewArrayTree) tree).getType() instanceof AnnotatedTypeTree) { return typeWithPreservedAnnotations(tree); @@ -514,12 +529,126 @@ private void reportInvalidOverridingMethodParamTypeError( } } } - if (result != null && result.isRaw()) { - // bail out of any checking involving raw types for now + return typeOrNullIfRaw(result); + } + } + + /** + * @param type a type to check + * @return the given type, or null if the type is a raw type + */ + private static @Nullable Type typeOrNullIfRaw(@Nullable Type type) { + if (type != null && type.isRaw()) { + return null; + } + return type; + } + + /** + * Gets the type of a constructor call using a diamond operator from its assignment context, if + * available. + */ + private @Nullable Type getDiamondTypeFromContext(NewClassTree tree, VisitorState state) { + return getDiamondTypeFromParentContext( + tree, state, castToNonNull(state.getPath().getParentPath())); + } + + /** + * Computes the assignment-context type for an inferred constructor call, given a path to its + * parent context. + */ + private @Nullable Type getDiamondTypeFromParentContext( + NewClassTree tree, VisitorState state, TreePath parentPath) { + Tree parent = parentPath.getLeaf(); + while (parent instanceof ParenthesizedTree) { + parentPath = parentPath.getParentPath(); + if (parentPath == null) { return null; } - return result; + parent = parentPath.getLeaf(); + } + if (parent instanceof VariableTree || parent instanceof AssignmentTree) { + return getTreeType(parent, state); } + if (parent instanceof ReturnTree) { + TreePath enclosingMethodOrLambda = + NullabilityUtil.findEnclosingMethodOrLambdaOrInitializer(parentPath); + if (enclosingMethodOrLambda != null + && enclosingMethodOrLambda.getLeaf() instanceof MethodTree enclosingMethod) { + Symbol.MethodSymbol methodSymbol = ASTHelpers.getSymbol(enclosingMethod); + if (methodSymbol != null) { + return methodSymbol.getReturnType(); + } + } + return null; + } + if (parent instanceof MethodInvocationTree parentInvocation) { + if (isGenericCallNeedingInference(parentInvocation)) { + // TODO support full integration of diamond constructor calls with generic method inference + // https://github.com/uber/NullAway/issues/1470 + // for now, just give up and return null + return null; + } + Type methodType = ASTHelpers.getType(parentInvocation.getMethodSelect()); + if (methodType == null) { + return null; + } + return getFormalParameterTypeForArgument(parentInvocation, methodType.asMethodType(), tree); + } + if (parent instanceof NewClassTree parentConstructorCall) { + // get the type returned by the parent constructor call + Type parentClassType = getTreeType(parentConstructorCall, state.withPath(parentPath)); + if (parentClassType != null) { + Symbol parentCtorSymbol = ASTHelpers.getSymbol(parentConstructorCall); + // get the proper type for the constructor, as a member of the type returned by the + // constructor + Type parentCtorType = + TypeSubstitutionUtils.memberType( + state.getTypes(), parentClassType, parentCtorSymbol, config); + return getFormalParameterTypeForArgument( + parentConstructorCall, parentCtorType.asMethodType(), tree); + } + } + if (parent instanceof ConditionalExpressionTree) { + // TODO infer diamond type from the overall conditional expression type + // tracked in https://github.com/uber/NullAway/issues/1477 + return null; + } + return null; + } + + /** + * Returns the inferred/declared formal parameter type corresponding to actual parameter {@code + * argumentTree}. + */ + private @Nullable Type getFormalParameterTypeForArgument( + Tree invocationTree, Type.MethodType invocationType, Tree argumentTree) { + AtomicReference<@Nullable Type> formalParamTypeRef = new AtomicReference<>(); + new InvocationArguments(invocationTree, invocationType) + .forEach( + (arg, pos, formalParamType, unused) -> { + if (ASTHelpers.stripParentheses(arg) == argumentTree) { + formalParamTypeRef.set(formalParamType); + } + }); + return formalParamTypeRef.get(); + } + + /** + * Returns true when javac inferred class type arguments for a constructor call, i.e. there are + * instantiated type arguments at the type level, but no explicit non-diamond source type args. + */ + private static boolean hasInferredClassTypeArguments(NewClassTree newClassTree) { + if (newClassTree.getClassBody() != null) { + // we still need to properly handle anonymous classes + return false; + } + if (!TreeInfo.isDiamond((JCTree) newClassTree)) { + // explicit class type arguments in source + return false; + } + Type newClassType = ASTHelpers.getType(newClassTree); + return newClassType != null && !newClassType.getTypeArguments().isEmpty(); } /** @@ -606,7 +735,8 @@ public void checkTypeParameterNullnessForAssignability(Tree tree, VisitorState s && isAssignmentToField(tree)) { maybeStoreLambdaTypeFromTarget(lambdaExpressionTree, lhsType); } - Type rhsType = getTreeType(rhsTree, state); + TreePath pathToRhs = new TreePath(state.getPath(), rhsTree); + Type rhsType = getTreeType(rhsTree, state.withPath(pathToRhs)); if (rhsType != null) { if (isGenericCallNeedingInference(rhsTree)) { rhsType = @@ -1298,6 +1428,7 @@ private Type updateTypeWithNullness( private static boolean isGenericCallNeedingInference(ExpressionTree argument) { // For now, we only support calls to generic methods. // TODO also support calls to generic constructors that use the diamond operator + // https://github.com/uber/NullAway/issues/1470 if (argument instanceof MethodInvocationTree methodInvocation) { Symbol.MethodSymbol methodSymbol = ASTHelpers.getSymbol(methodInvocation); // true for generic method calls with no explicit type arguments @@ -1328,7 +1459,8 @@ public void checkTypeParameterNullnessForFunctionReturnType( // bail out of any checking involving raw types for now return; } - Type returnExpressionType = getTreeType(retExpr, state); + TreePath pathToRetExpr = new TreePath(state.getPath(), retExpr); + Type returnExpressionType = getTreeType(retExpr, state.withPath(pathToRetExpr)); if (returnExpressionType != null) { if (isGenericCallNeedingInference(retExpr)) { returnExpressionType = @@ -1478,7 +1610,21 @@ public void compareGenericTypeParameterNullabilityForCall( return; } Type invokedMethodType = methodSymbol.type; - Type enclosingType = getEnclosingTypeForCallExpression(methodSymbol, tree, null, state, false); + Type enclosingType = null; + if (tree instanceof NewClassTree newClassTree) { + if (hasInferredClassTypeArguments(newClassTree)) { + TreePath currentPath = state.getPath(); + if (currentPath != null && ASTHelpers.stripParentheses(currentPath.getLeaf()) == tree) { + TreePath parentPath = currentPath.getParentPath(); + if (parentPath != null) { + enclosingType = getDiamondTypeFromParentContext(newClassTree, state, parentPath); + } + } + } + } + if (enclosingType == null) { + enclosingType = getEnclosingTypeForCallExpression(methodSymbol, tree, null, state, false); + } if (enclosingType != null) { invokedMethodType = TypeSubstitutionUtils.memberType(state.getTypes(), enclosingType, methodSymbol, config); @@ -1509,7 +1655,9 @@ public void compareGenericTypeParameterNullabilityForCall( if (inferredPolyType != null) { actualParameterType = inferredPolyType; } else { - actualParameterType = getTreeType(currentActualParam, state); + TreePath pathToActualParam = new TreePath(state.getPath(), currentActualParam); + actualParameterType = + getTreeType(currentActualParam, state.withPath(pathToActualParam)); } if (actualParameterType != null) { if (isGenericCallNeedingInference(currentActualParam)) { @@ -1874,16 +2022,12 @@ private InvocationAndContext getInvocationAndContextForInference( } // the generic invocation is either a regular parameter to the parent call, or the // receiver expression - AtomicReference<@Nullable Type> formalParamTypeRef = new AtomicReference<>(); - Type type = ASTHelpers.getSymbol(parentInvocation).type; - new InvocationArguments(parentInvocation, type.asMethodType()) - .forEach( - (arg, pos, formalParamType, unused) -> { - if (ASTHelpers.stripParentheses(arg) == invocation) { - formalParamTypeRef.set(formalParamType); - } - }); - Type formalParamType = formalParamTypeRef.get(); + Type formalParamType = + getFormalParameterTypeForArgument( + parentInvocation, + castToNonNull(ASTHelpers.getType(parentInvocation.getMethodSelect())) + .asMethodType(), + invocation); if (formalParamType == null) { // this can happen if the invocation is the receiver expression of the call, e.g., // id(x).foo() (note that foo() need not be generic) diff --git a/nullaway/src/test/java/com/uber/nullaway/jspecify/GenericDiamondTests.java b/nullaway/src/test/java/com/uber/nullaway/jspecify/GenericDiamondTests.java new file mode 100644 index 0000000000..07c4ac7ed5 --- /dev/null +++ b/nullaway/src/test/java/com/uber/nullaway/jspecify/GenericDiamondTests.java @@ -0,0 +1,221 @@ +package com.uber.nullaway.jspecify; + +import com.google.errorprone.CompilationTestHelper; +import com.uber.nullaway.NullAwayTestsBase; +import com.uber.nullaway.generics.JSpecifyJavacConfig; +import java.util.Arrays; +import org.junit.Test; + +public class GenericDiamondTests extends NullAwayTestsBase { + + @Test + public void assignToLocal() { + makeHelper() + .addSourceLines( + "Test.java", + """ + import org.jspecify.annotations.*; + @NullMarked + public class Test { + static class Foo { + static Foo<@Nullable Void> make() { + throw new RuntimeException(); + } + static Foo<@Nullable String> makeNullableStr() { + throw new RuntimeException(); + } + } + static class Bar { + Bar(Foo foo) { + } + } + void testNegative() { + // should be legal + Bar<@Nullable Void> b = new Bar<>(Foo.make()); + } + void testPositive() { + // BUG: Diagnostic contains: incompatible types: Foo<@Nullable String> cannot be converted to Foo + Bar b = new Bar<>(Foo.makeNullableStr()); + // BUG: Diagnostic contains: incompatible types: Foo<@Nullable Void> cannot be converted to Foo + Bar b2 = new Bar<>(Foo.make()); + } + } + """) + .doTest(); + } + + @Test + public void returnDiamond() { + makeHelper() + .addSourceLines( + "Test.java", + """ + import org.jspecify.annotations.*; + @NullMarked + public class Test { + static class Foo { + static Foo<@Nullable Void> make() { + throw new RuntimeException(); + } + static Foo<@Nullable String> makeNullableStr() { + throw new RuntimeException(); + } + } + static class Bar { + Bar(Foo foo) { + } + } + Bar<@Nullable Void> testNegative() { + // should be legal + return new Bar<>(Foo.make()); + } + Bar testPositive() { + // BUG: Diagnostic contains: incompatible types: Foo<@Nullable String> cannot be converted to Foo + return new Bar<>(Foo.makeNullableStr()); + } + } + """) + .doTest(); + } + + @Test + public void paramPassing() { + makeHelper() + .addSourceLines( + "Test.java", + """ + import org.jspecify.annotations.*; + @NullMarked + public class Test { + static class Foo { + static Foo<@Nullable Void> make() { + throw new RuntimeException(); + } + static Foo<@Nullable String> makeNullableStr() { + throw new RuntimeException(); + } + } + static class Bar { + Bar(Foo foo) { + } + } + static void takeNullableVoid(Bar<@Nullable Void> b) {} + static void takeStr(Bar b) {} + void testNegative() { + // should be legal + takeNullableVoid(new Bar<>(Foo.make())); + } + void testPositive() { + // BUG: Diagnostic contains: incompatible types: Foo<@Nullable String> cannot be converted to Foo + takeStr(new Bar<>(Foo.makeNullableStr())); + } + } + """) + .doTest(); + } + + @Test + public void parenthesizedDiamond() { + makeHelper() + .addSourceLines( + "Test.java", + """ + import org.jspecify.annotations.*; + @NullMarked + public class Test { + static class Foo { + static Foo<@Nullable Void> make() { + throw new RuntimeException(); + } + static Foo<@Nullable String> makeNullableStr() { + throw new RuntimeException(); + } + } + static class Bar { + Bar(Foo foo) { + } + } + Bar<@Nullable Void> testNegative() { + return (new Bar<>(Foo.make())); + } + Bar testPositive() { + // BUG: Diagnostic contains: incompatible types: Foo<@Nullable String> cannot be converted to Foo + return (new Bar<>(Foo.makeNullableStr())); + } + } + """) + .doTest(); + } + + @Test + public void nestedDiamondConstructors() { + makeHelper() + .addSourceLines( + "Test.java", + """ + import org.jspecify.annotations.*; + @NullMarked + public class Test { + static class Foo { + static Foo<@Nullable Void> make() { + throw new RuntimeException(); + } + static Foo<@Nullable String> makeNullableStr() { + throw new RuntimeException(); + } + } + static class Bar { + Bar(Foo foo) { + } + } + static class Baz { + Baz(Bar bar) { + } + } + Baz<@Nullable Void> testNegative() { + return new Baz<>(new Bar<>(Foo.make())); + } + Baz testPositive() { + // BUG: Diagnostic contains: incompatible types: Foo<@Nullable String> cannot be converted to Foo + return new Baz<>(new Bar<>(Foo.makeNullableStr())); + } + } + """) + .doTest(); + } + + @Test + public void diamondSubclassPassedToGenericMethod() { + makeHelper() + .addSourceLines( + "Test.java", + """ + import org.jspecify.annotations.*; + import java.util.List; + @NullMarked + public class Test { + interface Foo { + } + static class FooImpl implements Foo<@Nullable T> { + FooImpl(Class cls) { + } + } + static List make(Foo foo) { + throw new RuntimeException(); + } + static List<@Nullable V> test(Class cls) { + return make(new FooImpl<>(cls)); + } + } + """) + .doTest(); + } + + private CompilationTestHelper makeHelper() { + return makeTestHelperWithArgs( + JSpecifyJavacConfig.withJSpecifyModeArgs( + Arrays.asList( + "-XepOpt:NullAway:AnnotatedPackages=com.uber", + "-XepOpt:NullAway:WarnOnGenericInferenceFailure=true"))); + } +}