22
33import static com .google .common .base .Verify .verify ;
44import static com .uber .nullaway .NullabilityUtil .castToNonNull ;
5+ import static com .uber .nullaway .NullabilityUtil .findEnclosingMethodOrLambdaOrInitializer ;
56
67import com .google .errorprone .VisitorState ;
78import com .google .errorprone .suppliers .Supplier ;
4041import java .util .List ;
4142import java .util .Map ;
4243import java .util .Objects ;
44+ import java .util .stream .Collectors ;
4345import javax .lang .model .type .ExecutableType ;
4446import javax .lang .model .type .TypeKind ;
4547import javax .lang .model .type .TypeVariable ;
@@ -426,7 +428,7 @@ private static void reportInvalidOverridingMethodParamTypeError(
426428 * @param state the visitor state
427429 */
428430 public void checkTypeParameterNullnessForAssignability (
429- Tree tree , NullAway analysis , VisitorState state ) {
431+ Tree tree , NullAway analysis , VisitorState state , Config config ) {
430432 if (!analysis .getConfig ().isJSpecifyMode ()) {
431433 return ;
432434 }
@@ -450,13 +452,20 @@ public void checkTypeParameterNullnessForAssignability(
450452 // method call has a return type of class type
451453 if (methodSymbol .getReturnType () instanceof Type .ClassType ) {
452454 Type .ClassType returnType = (Type .ClassType ) methodSymbol .getReturnType ();
453- List <Type > rhsTypeArguments = returnType .getTypeArguments ();
455+ List <Symbol . TypeVariableSymbol > typeParam = methodSymbol .getTypeParameters ();
456+ List <Type > returnTypeTypeArg = returnType .getTypeArguments ();
457+
454458 // if generic type in return type
455- if (!rhsTypeArguments .isEmpty ()) {
459+ if (!typeParam .isEmpty ()) {
456460 Map <Type , Type > genericNullness = new HashMap <>();
457- for (int i = 0 ; i < rhsTypeArguments .size (); i ++) {
458- Type lhsInferredType = ASTHelpers .getType (lhsTypeArguments .get (i ));
459- genericNullness .put (rhsTypeArguments .get (i ), lhsInferredType );
461+ for (int i = 0 ; i < typeParam .size (); i ++) {
462+ Type upperBound = typeParam .get (i ).type .getUpperBound ();
463+ if (getTypeNullness (upperBound , config ) == Nullness .NULLABLE ) { // generic has nullable upperbound
464+ Type lhsInferredType = inferMethodTypeArgument (typeParam .get (i ).type , lhsTypeArguments , returnTypeTypeArg , state );
465+ if (lhsInferredType != null ) { // && has a nullable upperbound
466+ genericNullness .put (typeParam .get (i ).type , lhsInferredType );
467+ }
468+ }
460469 }
461470 inferredTypes .put (rhsTree , genericNullness );
462471 }
@@ -485,10 +494,13 @@ public void checkTypeParameterNullnessForAssignability(
485494
486495 if (inferredTypes .containsKey (rhsTree )) {
487496 Map <Type , Type > genericNullness = inferredTypes .get (rhsTree );
497+ List <Type > parameterTypes = rhsType .getTypeArguments ();
488498 for (int i = 0 ; i < typeParam .size (); i ++) {
489- if ( genericNullness . containsKey ( typeParam .get (i ).type )) {
490- var pType = typeParam . get ( i ). type ;
499+ Type pType = typeParam .get (i ).type ;
500+ if ( genericNullness . containsKey ( pType )) {
491501 newTypeArgument .add (genericNullness .get (pType )); // replace type to inferred types
502+ } else {
503+ newTypeArgument .add (parameterTypes .get (i ));
492504 }
493505 }
494506
@@ -511,6 +523,29 @@ public void checkTypeParameterNullnessForAssignability(
511523 }
512524 }
513525
526+ private Type inferMethodTypeArgument (Type typeParam , List <? extends Tree > lhsTypeArg , List <Type > typeArg , VisitorState state ) {
527+ // base case
528+ if (typeParam == null || lhsTypeArg == null || typeArg == null ) {
529+ return null ;
530+ }
531+
532+ // recursive case
533+ Type inferType = null ;
534+ for (int i =0 ; i <typeArg .size (); i ++) {
535+ Type type = typeArg .get (i );
536+ if (state .getTypes ().isSameType (typeParam , type )) {
537+ return ASTHelpers .getType (lhsTypeArg .get (i ));
538+ } else if (!type .getTypeArguments ().isEmpty ()) {
539+ // instanceof Type.ForAll TODO: check if the lhsTypeArg is a generic class? -> maybe the base case makes it unnecessary
540+ inferType = inferMethodTypeArgument (typeParam , ((ParameterizedTypeTree ) lhsTypeArg .get (i )).getTypeArguments (), typeArg .get (i ).getTypeArguments (), state );
541+ if (inferType != null ) {
542+ return inferType ;
543+ }
544+ }
545+ }
546+ return inferType ;
547+ }
548+
514549 /**
515550 * Checks that the nullability of type parameters for a returned expression matches that of the
516551 * type parameters of the enclosing method's return type.
0 commit comments