Skip to content

Commit 912dabe

Browse files
committed
Fix #405: Completely overhaul erasure of value classes.
Including when they contain arrays or are contained in arrays.
1 parent b5c32dd commit 912dabe

File tree

10 files changed

+375
-108
lines changed

10 files changed

+375
-108
lines changed

build.sbt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,11 @@ lazy val tastyQuery =
114114
import com.typesafe.tools.mima.core.*
115115
Seq(
116116
// private, not an issue
117+
ProblemFilters.exclude[MissingClassProblem]("tastyquery.Erasure$ErasedValueClass"),
118+
ProblemFilters.exclude[MissingClassProblem]("tastyquery.Erasure$ErasedValueClass$"),
117119
ProblemFilters.exclude[MissingClassProblem]("tastyquery.TypeOps$TypeFold"),
120+
// private[tastyquery], not an issue
121+
ProblemFilters.exclude[DirectMissingMethodProblem]("tastyquery.Signatures#Signature.toSigName"),
118122
// Everything in tastyquery.reader is private[tastyquery] at most
119123
ProblemFilters.exclude[Problem]("tastyquery.reader.*"),
120124
)

tasty-query/shared/src/main/scala/tastyquery/Definitions.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,15 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS
396396
lazy val CharClass = scalaPackage.requiredClass("Char")
397397
lazy val UnitClass = scalaPackage.requiredClass("Unit")
398398

399+
private[tastyquery] lazy val BoxedBooleanClass = javaLangPackage.requiredClass("Boolean")
400+
private[tastyquery] lazy val BoxedCharClass = javaLangPackage.requiredClass("Character")
401+
private[tastyquery] lazy val BoxedByteClass = javaLangPackage.requiredClass("Byte")
402+
private[tastyquery] lazy val BoxedShortClass = javaLangPackage.requiredClass("Short")
403+
private[tastyquery] lazy val BoxedIntClass = javaLangPackage.requiredClass("Integer")
404+
private[tastyquery] lazy val BoxedLongClass = javaLangPackage.requiredClass("Long")
405+
private[tastyquery] lazy val BoxedFloatClass = javaLangPackage.requiredClass("Float")
406+
private[tastyquery] lazy val BoxedDoubleClass = javaLangPackage.requiredClass("Double")
407+
399408
lazy val StringClass = javaLangPackage.requiredClass("String")
400409

401410
lazy val ProductClass = scalaPackage.requiredClass("Product")

tasty-query/shared/src/main/scala/tastyquery/Erasure.scala

Lines changed: 184 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,6 @@ import tastyquery.Types.*
1111
import tastyquery.Types.ErasedTypeRef.*
1212

1313
private[tastyquery] object Erasure:
14-
// TODO: improve this to match dotty:
15-
// - use correct type erasure algorithm from Scala 3, with specialisations
16-
// for Java types and Scala 2 types (i.e. varargs, value-classes)
17-
1814
@deprecated("use the overload that takes an explicit SourceLanguage", since = "0.7.1")
1915
def erase(tpe: Type)(using Context): ErasedTypeRef =
2016
erase(tpe, SourceLanguage.Scala3)
@@ -27,44 +23,41 @@ private[tastyquery] object Erasure:
2723
finishErase(preErase(tpe, keepUnit))
2824
end erase
2925

30-
/** First pass of erasure, where some special types are preserved as is.
26+
private[tastyquery] def eraseForSigName(tpe: Type, language: SourceLanguage, keepUnit: Boolean)(
27+
using Context
28+
): ErasedTypeRef =
29+
given SourceLanguage = language
30+
31+
val patchedPreErased = preErase(tpe, keepUnit) match
32+
case ArrayTypeRef(ClassRef(cls), dimensions) if cls.isDerivedValueClass =>
33+
// Hack! dotc's `sigName` does *not* correspond to erasure in this case!
34+
val patchedBase =
35+
if cls.typeParams.isEmpty then preEraseMonoValueClass(cls)
36+
else preErasePolyValueClass(cls, cls.typeParams.map(_.localRef))
37+
patchedBase.underlying.multiArrayOf(dimensions)
38+
case typeRef =>
39+
typeRef
40+
41+
finishErase(patchedPreErased)
42+
end eraseForSigName
43+
44+
private final case class ErasedValueClass(valueClass: ClassSymbol, underlying: ErasedTypeRef)
45+
46+
private type PreErasedTypeRef = ErasedTypeRef | ErasedValueClass
47+
48+
/** First pass of erasure, where some special types are preserved as is,
49+
* and where value classes become `ErasedValueClass`es.
3150
*
3251
* In particular, `Any` is preserved as `Any`, instead of becoming
3352
* `java.lang.Object`.
3453
*/
35-
private def preErase(tpe: Type, keepUnit: Boolean)(using Context, SourceLanguage): ErasedTypeRef =
36-
def arrayOfBounds(bounds: TypeBounds): ErasedTypeRef =
37-
preErase(bounds.high, keepUnit = false) match
38-
case ClassRef(cls) if cls.isAny || cls.isAnyVal =>
39-
ClassRef(defn.ObjectClass)
40-
case typeRef =>
41-
typeRef.arrayOf()
42-
43-
def arrayOf(tpe: TypeOrWildcard): ErasedTypeRef = tpe match
44-
case tpe: AppliedType =>
45-
tpe.tycon match
46-
case TypeRef.OfClass(cls) =>
47-
if cls.isArray then
48-
val List(targ) = tpe.args: @unchecked
49-
arrayOf(targ).arrayOf()
50-
else ClassRef(cls).arrayOf()
51-
case _ =>
52-
arrayOf(tpe.translucentSuperType)
53-
case TypeRef.OfClass(cls) =>
54-
if cls.isUnit then ClassRef(defn.ErasedBoxedUnitClass).arrayOf()
55-
else ClassRef(cls).arrayOf()
56-
case tpe: TypeRef =>
57-
tpe.optSymbol match
58-
case Some(sym: TypeMemberSymbol) if sym.isOpaqueTypeAlias =>
59-
arrayOf(tpe.translucentSuperType)
60-
case _ =>
61-
tpe.bounds match
62-
case bounds: AbstractTypeBounds => arrayOfBounds(bounds)
63-
case TypeAlias(alias) => arrayOf(alias)
64-
case tpe: TypeParamRef => arrayOfBounds(tpe.bounds)
65-
case tpe: Type => preErase(tpe, keepUnit = false).arrayOf()
66-
case tpe: WildcardTypeArg => arrayOfBounds(tpe.bounds)
67-
end arrayOf
54+
private def preErase(tpe: Type, keepUnit: Boolean)(using Context, SourceLanguage): PreErasedTypeRef =
55+
def arrayOf(tpe: TypeOrWildcard): ErasedTypeRef =
56+
if isGenericArrayElement(tpe) then ClassRef(defn.ObjectClass)
57+
else
58+
preErase(tpe.highIfWildcard, keepUnit = false) match
59+
case base: ErasedTypeRef => base.arrayOf()
60+
case ErasedValueClass(valueClass, _) => ClassRef(valueClass).arrayOf()
6861

6962
tpe match
7063
case tpe: AppliedType =>
@@ -73,11 +66,13 @@ private[tastyquery] object Erasure:
7366
if cls.isArray then
7467
val List(targ) = tpe.args: @unchecked
7568
arrayOf(targ)
69+
else if cls.isDerivedValueClass then preErasePolyValueClass(cls, tpe.args)
7670
else ClassRef(cls)
7771
case _ =>
7872
preErase(tpe.translucentSuperType, keepUnit)
7973
case TypeRef.OfClass(cls) =>
8074
if !keepUnit && cls.isUnit then ClassRef(defn.ErasedBoxedUnitClass)
75+
else if cls.isDerivedValueClass then preEraseMonoValueClass(cls)
8176
else ClassRef(cls)
8277
case tpe: TypeRef =>
8378
preErase(tpe.translucentSuperType, keepUnit)
@@ -90,7 +85,10 @@ private[tastyquery] object Erasure:
9085
case Some(reduced) => preErase(reduced, keepUnit)
9186
case None => preErase(tpe.bound, keepUnit)
9287
case tpe: OrType =>
93-
erasedLub(preErase(tpe.first, keepUnit = false), preErase(tpe.second, keepUnit = false))
88+
erasedLub(
89+
finishErase(preErase(tpe.first, keepUnit = false)),
90+
finishErase(preErase(tpe.second, keepUnit = false))
91+
)
9492
case tpe: AndType =>
9593
summon[SourceLanguage] match
9694
case SourceLanguage.Java =>
@@ -120,29 +118,157 @@ private[tastyquery] object Erasure:
120118
throw IllegalArgumentException(s"Unexpected type in erasure: $tpe")
121119
end preErase
122120

123-
private def finishErase(typeRef: ErasedTypeRef)(using Context): ErasedTypeRef =
121+
private def finishErase(typeRef: PreErasedTypeRef)(using Context, SourceLanguage): ErasedTypeRef =
124122
typeRef match
125-
case ClassRef(cls) =>
126-
if cls.isDerivedValueClass then finishEraseValueClass(cls)
127-
else cls.erasure
128-
case ArrayTypeRef(ClassRef(cls), dimensions) =>
129-
ArrayTypeRef(cls.erasure, dimensions)
123+
case ClassRef(cls) => cls.erasure
124+
case ArrayTypeRef(ClassRef(cls), dimensions) => ArrayTypeRef(cls.erasure, dimensions)
125+
case ErasedValueClass(_, underlying) => finishErase(underlying)
130126
end finishErase
131127

132-
private def finishEraseValueClass(cls: ClassSymbol)(using Context): ErasedTypeRef =
128+
private def preEraseMonoValueClass(cls: ClassSymbol)(using Context, SourceLanguage): ErasedValueClass =
129+
val ctor = cls.findNonOverloadedDecl(nme.Constructor)
130+
131+
val underlying = ctor.declaredType match
132+
case tpe: MethodType if tpe.paramNames.sizeIs == 1 =>
133+
tpe.paramTypes.head
134+
case _ =>
135+
throw InvalidProgramStructureException(s"Illegal value class constructor type ${ctor.declaredType.showBasic}")
136+
137+
// The underlying types of value classes are never value classes themselves (by language spec)
138+
val erasedUnderlying = preErase(underlying, keepUnit = false).asInstanceOf[ErasedTypeRef]
139+
140+
ErasedValueClass(cls, erasedUnderlying)
141+
end preEraseMonoValueClass
142+
143+
private def preErasePolyValueClass(cls: ClassSymbol, targs: List[TypeOrWildcard])(
144+
using Context,
145+
SourceLanguage
146+
): ErasedValueClass =
133147
val ctor = cls.findNonOverloadedDecl(nme.Constructor)
134148

135149
def illegalConstructorType(): Nothing =
136150
throw InvalidProgramStructureException(s"Illegal value class constructor type ${ctor.declaredType.showBasic}")
137151

138152
def ctorParamType(tpe: TypeOrMethodic): Type = tpe match
139153
case tpe: MethodType if tpe.paramTypes.sizeIs == 1 => tpe.paramTypes.head
140-
case tpe: MethodType => illegalConstructorType()
141-
case tpe: PolyType => ctorParamType(tpe.resultType)
142-
case tpe: Type => illegalConstructorType()
154+
case _ => illegalConstructorType()
155+
156+
val ctorPolyType = ctor.declaredType match
157+
case tpe: PolyType => tpe
158+
case _ => illegalConstructorType()
159+
160+
val genericUnderlying = ctorParamType(ctorPolyType.resultType)
161+
val specializedUnderlying = ctorParamType(ctorPolyType.instantiate(targs))
162+
163+
// The underlying types of value classes are never value classes themselves (by language spec)
164+
val erasedGenericUnderlying = preErase(genericUnderlying, keepUnit = false).asInstanceOf[ErasedTypeRef]
165+
val erasedSpecializedUnderlying = preErase(specializedUnderlying, keepUnit = false).asInstanceOf[ErasedTypeRef]
143166

144-
erase(ctorParamType(ctor.declaredType), ctor.sourceLanguage)
145-
end finishEraseValueClass
167+
def isPrimitive(typeRef: ErasedTypeRef): Boolean = typeRef match
168+
case ClassRef(cls) => cls.isPrimitiveValueClass
169+
case _: ArrayTypeRef => false
170+
171+
/* Ideally, we would just use `erasedSpecializedUnderlying` as the erasure of `tp`.
172+
* However, there are two special cases for polymorphic value classes, which
173+
* historically come from Scala 2:
174+
*
175+
* - Given `class Foo[A](x: A) extends AnyVal`, `Foo[X]` should erase like
176+
* `X`, except if it's a primitive in which case it erases to the boxed
177+
* version of this primitive.
178+
* - Given `class Bar[A](x: Array[A]) extends AnyVal`, `Bar[X]` will be
179+
* erased like `Array[A]` as seen from its definition site, no matter
180+
* the `X` (same if `A` is bounded).
181+
*/
182+
val erasedValueClass =
183+
if isPrimitive(erasedSpecializedUnderlying) && !isPrimitive(erasedGenericUnderlying) then
184+
ClassRef(erasedSpecializedUnderlying.asInstanceOf[ClassRef].cls.boxedClass)
185+
else if genericUnderlying.baseType(defn.ArrayClass).isDefined then erasedGenericUnderlying
186+
else erasedSpecializedUnderlying
187+
188+
ErasedValueClass(cls, erasedValueClass)
189+
end preErasePolyValueClass
190+
191+
/** Is `Array[tp]` a generic Array that needs to be erased to `Object`?
192+
* This is true if among the subtypes of `Array[tp]` there is either:
193+
* - both a reference array type and a primitive array type
194+
* (e.g. `Array[_ <: Int | String]`, `Array[_ <: Any]`)
195+
* - or two different primitive array types (e.g. `Array[_ <: Int | Double]`)
196+
* In both cases the erased lub of those array types on the JVM is `Object`.
197+
*
198+
* In addition, if the source language is Scala 2, we mimic the Scala 2 erasure rules and
199+
* also return true for element types upper-bounded by a non-reference type
200+
* such as in `Array[_ <: Int]` or `Array[_ <: UniversalTrait]`.
201+
*/
202+
private def isGenericArrayElement(tp: TypeOrWildcard)(using Context, SourceLanguage): Boolean =
203+
/** A symbol that represents the sort of JVM array that values of type `tp` can be stored in:
204+
* - If we can always store such values in a reference array, return `j.l.Object`.
205+
* - If we can always store them in a specific primitive array, return the corresponding primitive class.
206+
* - Otherwise, return `None`.
207+
*/
208+
def arrayUpperBound(tp: Type): Option[ClassSymbol] = tp.dealias match
209+
case TypeRef.OfClass(cls) =>
210+
def isScala2SpecialCase: Boolean =
211+
summon[SourceLanguage] == SourceLanguage.Scala2
212+
&& !cls.isNull
213+
&& !cls.isSubClass(defn.ObjectClass)
214+
215+
// Only a few classes have both primitives and references as subclasses.
216+
if cls.isAny || cls.isAnyVal || cls.isMatchable || cls.isSingleton || isScala2SpecialCase then None
217+
else if cls.isPrimitiveValueClass then Some(cls)
218+
else
219+
// Derived value classes in arrays are always boxed, so they end up here as well
220+
Some(defn.ObjectClass)
221+
222+
case tp: TypeProxy =>
223+
arrayUpperBound(tp.translucentSuperType)
224+
case tp: AndType =>
225+
arrayUpperBound(tp.first).orElse(arrayUpperBound(tp.second))
226+
case tp: OrType =>
227+
val firstBound = arrayUpperBound(tp.first)
228+
val secondBound = arrayUpperBound(tp.second) // bound of the right-hand alternative of the union
229+
if firstBound == secondBound then firstBound
230+
else None
231+
case _: NothingType | _: AnyKindType | _: TypeLambda =>
232+
None
233+
case tp: CustomTransientGroundType =>
234+
throw IllegalArgumentException(s"Unexpected transient type: $tp")
235+
end arrayUpperBound
236+
237+
/** Can one of the JVM array types store all possible values of type `tp`? */
238+
def fitsInJVMArray(tp: Type): Boolean = arrayUpperBound(tp).isDefined
239+
240+
tp match
241+
case tp: WildcardTypeArg =>
242+
!fitsInJVMArray(tp.bounds.high)
243+
244+
case tp: Type =>
245+
tp.dealias match
246+
case tp: TypeRef =>
247+
tp.optSymbol match
248+
case Some(cls: ClassSymbol) =>
249+
false
250+
case Some(sym: TypeMemberSymbol) if sym.isOpaqueTypeAlias =>
251+
isGenericArrayElement(tp.translucentSuperType)
252+
case _ =>
253+
tp.bounds match
254+
case TypeAlias(alias) => isGenericArrayElement(alias)
255+
case AbstractTypeBounds(_, high) => !fitsInJVMArray(high)
256+
case tp: TypeParamRef =>
257+
!fitsInJVMArray(tp)
258+
case tp: MatchType =>
259+
val cases = tp.cases
260+
cases.nonEmpty && !fitsInJVMArray(cases.map(_.result).reduce(OrType(_, _)))
261+
case tp: TypeProxy =>
262+
isGenericArrayElement(tp.translucentSuperType)
263+
case tp: AndType =>
264+
isGenericArrayElement(tp.first) && isGenericArrayElement(tp.second)
265+
case tp: OrType =>
266+
isGenericArrayElement(tp.first) || isGenericArrayElement(tp.second)
267+
case _: NothingType | _: AnyKindType | _: TypeLambda =>
268+
false
269+
case tp: CustomTransientGroundType =>
270+
throw IllegalArgumentException(s"Unexpected transient type: $tp")
271+
end isGenericArrayElement
146272

147273
/** The erased least upper bound of two erased types is computed as follows.
148274
*
@@ -224,7 +350,7 @@ private[tastyquery] object Erasure:
224350
* - Associativity and commutativity, because this method acts as the minimum
225351
* of the total order induced by `compareErasedGlb`.
226352
*/
227-
private def erasedGlb(tp1: ErasedTypeRef, tp2: ErasedTypeRef)(using Context): ErasedTypeRef =
353+
private def erasedGlb(tp1: PreErasedTypeRef, tp2: PreErasedTypeRef)(using Context): PreErasedTypeRef =
228354
if compareErasedGlb(tp1, tp2) <= 0 then tp1
229355
else tp2
230356

@@ -248,7 +374,7 @@ private[tastyquery] object Erasure:
248374
*
249375
* @see erasedGlb
250376
*/
251-
private def compareErasedGlb(tp1: ErasedTypeRef, tp2: ErasedTypeRef)(using Context): Int =
377+
private def compareErasedGlb(tp1: PreErasedTypeRef, tp2: PreErasedTypeRef)(using Context): Int =
252378
def compareClasses(cls1: ClassSymbol, cls2: ClassSymbol): Int =
253379
if cls1.isSubClass(cls2) then -1
254380
else if cls2.isSubClass(cls1) then 1
@@ -260,13 +386,11 @@ private[tastyquery] object Erasure:
260386
// fast path
261387
0
262388

263-
case (ClassRef(cls1), _) if cls1.isDerivedValueClass =>
264-
tp2 match
265-
case ClassRef(cls2) if cls2.isDerivedValueClass =>
266-
compareClasses(cls1, cls2)
267-
case _ =>
268-
-1
269-
case (_, ClassRef(cls2)) if cls2.isDerivedValueClass =>
389+
case (ErasedValueClass(cls1, _), ErasedValueClass(cls2, _)) =>
390+
compareClasses(cls1, cls2)
391+
case (ErasedValueClass(cls1, _), _) =>
392+
-1
393+
case (_, ErasedValueClass(cls2, _)) =>
270394
1
271395

272396
case (tp1: ArrayTypeRef, tp2: ArrayTypeRef) =>

0 commit comments

Comments
 (0)