@@ -53,37 +53,45 @@ object Nullables:
5353 TypeBoundsTree (lo, hiTree, alias)
5454
5555 /** A set of val or var references that are known to be not null
56- * after the tree finishes executing normally (non-exceptionally),
56+ * after the tree finishes executing normally (non-exceptionally),
5757 * plus a set of variable references that are ever assigned to null,
5858 * and may therefore be null if execution of the tree is interrupted
5959 * by an exception.
6060 */
61- case class NotNullInfo (asserted : Set [TermRef ], retracted : Set [TermRef ]):
61+ case class NotNullInfo (asserted : Set [TermRef ] | Null , retracted : Set [TermRef ]):
6262 def isEmpty = this eq NotNullInfo .empty
6363
6464 def retractedInfo = NotNullInfo (Set (), retracted)
6565
66+ def terminatedInfo = NotNullInfo (null , retracted)
67+
6668 /** The sequential combination with another not-null info */
6769 def seq (that : NotNullInfo ): NotNullInfo =
6870 if this .isEmpty then that
6971 else if that.isEmpty then this
70- else NotNullInfo (
71- this .asserted.diff(that.retracted).union(that.asserted),
72- this .retracted.union(that.retracted))
72+ else
73+ val newAsserted =
74+ if this .asserted == null || that.asserted == null then null
75+ else this .asserted.diff(that.retracted).union(that.asserted)
76+ val newRetracted = this .retracted.union(that.retracted)
77+ NotNullInfo (newAsserted, newRetracted)
7378
7479 /** The alternative path combination with another not-null info. Used to merge
75- * the nullability info of the two branches of an if.
80+ * the nullability info of the branches of an if or match .
7681 */
7782 def alt (that : NotNullInfo ): NotNullInfo =
78- NotNullInfo (this .asserted.intersect(that.asserted), this .retracted.union(that.retracted))
79-
80- def withRetracted (that : NotNullInfo ): NotNullInfo =
81- NotNullInfo (this .asserted, this .retracted.union(that.retracted))
83+ val newAsserted =
84+ if this .asserted == null then that.asserted
85+ else if that.asserted == null then this .asserted
86+ else this .asserted.intersect(that.asserted)
87+ val newRetracted = this .retracted.union(that.retracted)
88+ NotNullInfo (newAsserted, newRetracted)
89+ end NotNullInfo
8290
8391 object NotNullInfo :
8492 val empty = new NotNullInfo (Set (), Set ())
85- def apply (asserted : Set [TermRef ], retracted : Set [TermRef ]): NotNullInfo =
86- if asserted.isEmpty && retracted.isEmpty then empty
93+ def apply (asserted : Set [TermRef ] | Null , retracted : Set [TermRef ]): NotNullInfo =
94+ if asserted != null && asserted .isEmpty && retracted.isEmpty then empty
8795 else new NotNullInfo (asserted, retracted)
8896 end NotNullInfo
8997
@@ -227,7 +235,7 @@ object Nullables:
227235 */
228236 @ tailrec def impliesNotNull (ref : TermRef ): Boolean = infos match
229237 case info :: infos1 =>
230- if info.asserted.contains(ref) then true
238+ if info.asserted != null && info.asserted .contains(ref) then true
231239 else if info.retracted.contains(ref) then false
232240 else infos1.impliesNotNull(ref)
233241 case _ =>
@@ -243,7 +251,9 @@ object Nullables:
243251 /** Retract all references to mutable variables */
244252 def retractMutables (using Context ) =
245253 val mutables = infos.foldLeft(Set [TermRef ]()):
246- (ms, info) => ms.union(info.asserted.filter(_.symbol.is(Mutable )))
254+ (ms, info) => ms.union(
255+ if info.asserted == null then Set .empty
256+ else info.asserted.filter(_.symbol.is(Mutable )))
247257 infos.extendWith(NotNullInfo (Set (), mutables))
248258
249259 end extension
@@ -516,7 +526,10 @@ object Nullables:
516526 && assignmentSpans.getOrElse(sym.span.start, Nil ).exists(whileSpan.contains(_))
517527 && ctx.notNullInfos.impliesNotNull(ref)
518528
519- val retractedVars = ctx.notNullInfos.flatMap(_.asserted.filter(isRetracted)).toSet
529+ val retractedVars = ctx.notNullInfos.flatMap(info =>
530+ if info.asserted == null then Set .empty
531+ else info.asserted.filter(isRetracted)
532+ ).toSet
520533 ctx.addNotNullInfo(NotNullInfo (Set (), retractedVars))
521534 end whileContext
522535
0 commit comments