Skip to content

Commit 71fcd1a

Browse files
authored
Modify rule for nullable union types in generic signatures (#24129)
2 parents 23f5e32 + 676710c commit 71fcd1a

File tree

7 files changed

+45
-12
lines changed

7 files changed

+45
-12
lines changed

compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ object NullOpsDecorator:
4242
}
4343
if tpStripped ne tpWiden then tpStripped else tp
4444

45-
if ctx.explicitNulls then strip(self) else self
45+
strip(self)
4646
}
4747

4848
/** Is self (after widening and dealiasing) a type of the form `T | Null`? */

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1451,7 +1451,7 @@ object Types extends TypeUtils {
14511451
* then the top-level union isn't widened. This is needed so that type inference can infer nullable types.
14521452
*/
14531453
def widenUnion(using Context): Type = widen match
1454-
case tp: OrType =>
1454+
case tp: OrType if ctx.explicitNulls =>
14551455
val tp1 = tp.stripNull(stripFlexibleTypes = false)
14561456
if tp1 ne tp then
14571457
val tp1Widen = tp1.widenUnionWithoutNull

compiler/src/dotty/tools/dotc/transform/GenericSignatures.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,12 @@ object GenericSignatures {
304304
builder.append(')')
305305
methodResultSig(rte)
306306

307+
case OrNull(tp1) if !tp1.derivesFrom(defn.AnyValClass) =>
308+
// Special case for nullable union types whose underlying type is not a value class.
309+
// For example, `T | Null` where `T` is a type parameter becomes `T` in the signature;
310+
// `Int | Null` still becomes `Object`.
311+
jsig1(tp1)
312+
307313
case tp: AndType =>
308314
// Only intersections appearing as the upper-bound of a type parameter
309315
// can be preserved in generic signatures and those are already
@@ -455,15 +461,15 @@ object GenericSignatures {
455461
else x
456462
}
457463

458-
private def collectMethodParams(mtd: MethodOrPoly)(using Context): (List[TypeParamInfo], List[Type], Type) =
464+
private def collectMethodParams(mtd: MethodOrPoly)(using Context): (List[TypeParamInfo], List[Type], Type) =
459465
val tparams = ListBuffer.empty[TypeParamInfo]
460466
val vparams = ListBuffer.empty[Type]
461467

462468
@tailrec def recur(tpe: Type): Type = tpe match
463469
case mtd: MethodType =>
464470
vparams ++= mtd.paramInfos.filterNot(_.hasAnnotation(defn.ErasedParamAnnot))
465471
recur(mtd.resType)
466-
case PolyType(tps, tpe) =>
472+
case PolyType(tps, tpe) =>
467473
tparams ++= tps
468474
recur(tpe)
469475
case _ =>

compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ trait TypeAssigner {
176176
val qualType = qual.tpe.widenIfUnstable
177177
def kind = if tree.isType then "type" else "value"
178178
val foundWithoutNull = qualType match
179-
case OrNull(qualType1) if qualType1 <:< defn.ObjectType =>
179+
case OrNull(qualType1) if ctx.explicitNulls && qualType1 <:< defn.ObjectType =>
180180
val name = tree.name
181181
val pre = maybeSkolemizePrefix(qualType1, name)
182182
reallyExists(qualType1.findMember(name, pre))

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -588,11 +588,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
588588
*/
589589
def toNotNullTermRef(tree: Tree, pt: Type)(using Context): Tree = tree.tpe match
590590
case ref: TermRef
591-
if pt != LhsProto && // Ensure it is not the lhs of Assign
592-
ctx.notNullInfos.impliesNotNull(ref) &&
593-
// If a reference is in the context, it is already trackable at the point we add it.
594-
// Hence, we don't use isTracked in the next line, because checking use out of order is enough.
595-
!ref.usedOutOfOrder =>
591+
if ctx.explicitNulls
592+
&& pt != LhsProto // Ensure it is not the lhs of Assign
593+
&& ctx.notNullInfos.impliesNotNull(ref)
594+
// If a reference is in the context, it is already trackable at the point we add it.
595+
// Hence, we don't use isTracked in the next line, because checking use out of order is enough.
596+
&& !ref.usedOutOfOrder =>
596597
ref match
597598
case OrNull(tpnn) => tree.cast(AndType(ref, tpnn))
598599
case _ => tree
@@ -2228,7 +2229,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
22282229
given Context = caseCtx
22292230
val case1 = typedCase(cas, sel, wideSelType, tpe)
22302231
caseCtx = Nullables.afterPatternContext(sel, case1.pat)
2231-
if !alreadyStripped && Nullables.matchesNull(case1) then
2232+
if ctx.explicitNulls && !alreadyStripped && Nullables.matchesNull(case1) then
22322233
wideSelType = wideSelType.stripNull()
22332234
alreadyStripped = true
22342235
case1
@@ -2261,7 +2262,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
22612262
given Context = caseCtx
22622263
val case1 = typedCase(cas, sel, wideSelType, pt)
22632264
caseCtx = Nullables.afterPatternContext(sel, case1.pat)
2264-
if !alreadyStripped && Nullables.matchesNull(case1) then
2265+
if ctx.explicitNulls && !alreadyStripped && Nullables.matchesNull(case1) then
22652266
wideSelType = wideSelType.stripNull()
22662267
alreadyStripped = true
22672268
case1
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
f1(A): A
2+
f2(A): A
3+
g1(T): T
4+
g2(T): T
5+
i(java.lang.Object): java.lang.Object
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
class C[T]:
2+
def f1[A](a: A | Null): A | Null = ???
3+
def f2[A](a: A): A = ???
4+
def g1(a: T | Null): T | Null = ???
5+
def g2(a: T): T = ???
6+
def i(a: Int | Null): Int | Null = ???
7+
8+
object Test:
9+
10+
def printGenericSignature(m: java.lang.reflect.Method): Unit =
11+
val tpe = m.getGenericParameterTypes().map(_.getTypeName).mkString(", ")
12+
val ret = m.getGenericReturnType().getTypeName
13+
println(s"${m.getName}($tpe): $ret")
14+
15+
def main(args: Array[String]): Unit =
16+
val c = classOf[C[_]]
17+
printGenericSignature(c.getDeclaredMethod("f1", classOf[Object]))
18+
printGenericSignature(c.getDeclaredMethod("f2", classOf[Object]))
19+
printGenericSignature(c.getDeclaredMethod("g1", classOf[Object]))
20+
printGenericSignature(c.getDeclaredMethod("g2", classOf[Object]))
21+
printGenericSignature(c.getDeclaredMethod("i", classOf[Object]))

0 commit comments

Comments
 (0)