@@ -294,11 +294,29 @@ object Nullables:
294294 if ! info.isEmpty then tree.putAttachment(NNInfo , info)
295295 tree
296296
297+ /* Collect the nullability info from parts of `tree` */
298+ def collectNotNullInfo (using Context ): NotNullInfo = tree match
299+ case Typed (expr, _) =>
300+ expr.notNullInfo
301+ case Apply (fn, args) =>
302+ val argsInfo = args.map(_.notNullInfo)
303+ val fnInfo = fn.notNullInfo
304+ argsInfo.foldLeft(fnInfo)(_ seq _)
305+ case TypeApply (fn, _) =>
306+ fn.notNullInfo
307+ case _ =>
308+ // Other cases are handled specially in typer.
309+ NotNullInfo .empty
310+
297311 /* The nullability info of `tree` */
298312 def notNullInfo (using Context ): NotNullInfo =
299- stripInlined(tree).getAttachment(NNInfo ) match
313+ val tree1 = stripInlined(tree)
314+ tree1.getAttachment(NNInfo ) match
300315 case Some (info) if ! ctx.erasedTypes => info
301- case _ => NotNullInfo .empty
316+ case _ =>
317+ val nnInfo = tree1.collectNotNullInfo
318+ tree1.withNotNullInfo(nnInfo)
319+ nnInfo
302320
303321 /* The nullability info of `tree`, assuming it is a condition that evaluates to `c` */
304322 def notNullInfoIf (c : Boolean )(using Context ): NotNullInfo =
@@ -379,21 +397,23 @@ object Nullables:
379397 end extension
380398
381399 extension (tree : Assign )
382- def computeAssignNullable ()(using Context ): tree.type = tree.lhs match
383- case TrackedRef (ref) =>
384- val rhstp = tree.rhs.typeOpt
385- if ctx.explicitNulls && ref.isNullableUnion then
386- if rhstp.isNullType || rhstp.isNullableUnion then
387- // If the type of rhs is nullable (`T|Null` or `Null`), then the nullability of the
388- // lhs variable is no longer trackable. We don't need to check whether the type `T`
389- // is correct here, as typer will check it.
390- tree.withNotNullInfo(NotNullInfo (Set (), Set (ref)))
391- else
392- // If the initial type is nullable and the assigned value is non-null,
393- // we add it to the NotNull.
394- tree.withNotNullInfo(NotNullInfo (Set (ref), Set ()))
395- else tree
396- case _ => tree
400+ def computeAssignNullable ()(using Context ): tree.type =
401+ var nnInfo = tree.rhs.notNullInfo
402+ tree.lhs match
403+ case TrackedRef (ref) if ctx.explicitNulls && ref.isNullableUnion =>
404+ nnInfo = nnInfo.seq:
405+ val rhstp = tree.rhs.typeOpt
406+ if rhstp.isNullType || rhstp.isNullableUnion then
407+ // If the type of rhs is nullable (`T|Null` or `Null`), then the nullability of the
408+ // lhs variable is no longer trackable. We don't need to check whether the type `T`
409+ // is correct here, as typer will check it.
410+ NotNullInfo (Set (), Set (ref))
411+ else
412+ // If the initial type is nullable and the assigned value is non-null,
413+ // we add it to the NotNull.
414+ NotNullInfo (Set (ref), Set ())
415+ case _ =>
416+ tree.withNotNullInfo(nnInfo)
397417 end extension
398418
399419 private val analyzedOps = Set (nme.EQ , nme.NE , nme.eq, nme.ne, nme.ZAND , nme.ZOR , nme.UNARY_! )
0 commit comments