@@ -49,37 +49,45 @@ object Nullables:
4949 TypeBoundsTree (lo, newHi, alias)
5050
5151 /** A set of val or var references that are known to be not null
52- * after the tree finishes executing normally (non-exceptionally),
52+ * after the tree finishes executing normally (non-exceptionally),
5353 * plus a set of variable references that are ever assigned to null,
5454 * and may therefore be null if execution of the tree is interrupted
5555 * by an exception.
5656 */
57- case class NotNullInfo (asserted : Set [TermRef ], retracted : Set [TermRef ]):
57+ case class NotNullInfo (asserted : Set [TermRef ] | Null , retracted : Set [TermRef ]):
5858 def isEmpty = this eq NotNullInfo .empty
5959
6060 def retractedInfo = NotNullInfo (Set (), retracted)
6161
62+ def terminatedInfo = NotNullInfo (null , retracted)
63+
6264 /** The sequential combination with another not-null info */
6365 def seq (that : NotNullInfo ): NotNullInfo =
6466 if this .isEmpty then that
6567 else if that.isEmpty then this
66- else NotNullInfo (
67- this .asserted.diff(that.retracted).union(that.asserted),
68- this .retracted.union(that.retracted))
68+ else
69+ val newAsserted =
70+ if this .asserted == null || that.asserted == null then null
71+ else this .asserted.diff(that.retracted).union(that.asserted)
72+ val newRetracted = this .retracted.union(that.retracted)
73+ NotNullInfo (newAsserted, newRetracted)
6974
7075 /** The alternative path combination with another not-null info. Used to merge
71- * the nullability info of the two branches of an if.
76+ * the nullability info of the branches of an if or match .
7277 */
7378 def alt (that : NotNullInfo ): NotNullInfo =
74- NotNullInfo (this .asserted.intersect(that.asserted), this .retracted.union(that.retracted))
75-
76- def withRetracted (that : NotNullInfo ): NotNullInfo =
77- NotNullInfo (this .asserted, this .retracted.union(that.retracted))
79+ val newAsserted =
80+ if this .asserted == null then that.asserted
81+ else if that.asserted == null then this .asserted
82+ else this .asserted.intersect(that.asserted)
83+ val newRetracted = this .retracted.union(that.retracted)
84+ NotNullInfo (newAsserted, newRetracted)
85+ end NotNullInfo
7886
7987 object NotNullInfo :
8088 val empty = new NotNullInfo (Set (), Set ())
81- def apply (asserted : Set [TermRef ], retracted : Set [TermRef ]): NotNullInfo =
82- if asserted.isEmpty && retracted.isEmpty then empty
89+ def apply (asserted : Set [TermRef ] | Null , retracted : Set [TermRef ]): NotNullInfo =
90+ if asserted != null && asserted .isEmpty && retracted.isEmpty then empty
8391 else new NotNullInfo (asserted, retracted)
8492 end NotNullInfo
8593
@@ -202,7 +210,7 @@ object Nullables:
202210 */
203211 @ tailrec def impliesNotNull (ref : TermRef ): Boolean = infos match
204212 case info :: infos1 =>
205- if info.asserted.contains(ref) then true
213+ if info.asserted != null && info.asserted .contains(ref) then true
206214 else if info.retracted.contains(ref) then false
207215 else infos1.impliesNotNull(ref)
208216 case _ =>
@@ -218,7 +226,9 @@ object Nullables:
218226 /** Retract all references to mutable variables */
219227 def retractMutables (using Context ) =
220228 val mutables = infos.foldLeft(Set [TermRef ]()):
221- (ms, info) => ms.union(info.asserted.filter(_.symbol.is(Mutable )))
229+ (ms, info) => ms.union(
230+ if info.asserted == null then Set .empty
231+ else info.asserted.filter(_.symbol.is(Mutable )))
222232 infos.extendWith(NotNullInfo (Set (), mutables))
223233
224234 end extension
@@ -491,7 +501,10 @@ object Nullables:
491501 && assignmentSpans.getOrElse(sym.span.start, Nil ).exists(whileSpan.contains(_))
492502 && ctx.notNullInfos.impliesNotNull(ref)
493503
494- val retractedVars = ctx.notNullInfos.flatMap(_.asserted.filter(isRetracted)).toSet
504+ val retractedVars = ctx.notNullInfos.flatMap(info =>
505+ if info.asserted == null then Set .empty
506+ else info.asserted.filter(isRetracted)
507+ ).toSet
495508 ctx.addNotNullInfo(NotNullInfo (Set (), retractedVars))
496509 end whileContext
497510
0 commit comments