@@ -11,10 +11,6 @@ import tastyquery.Types.*
1111import tastyquery .Types .ErasedTypeRef .*
1212
1313private [tastyquery] object Erasure :
14- // TODO: improve this to match dotty:
15- // - use correct type erasure algorithm from Scala 3, with specialisations
16- // for Java types and Scala 2 types (i.e. varargs, value-classes)
17-
1814 @ deprecated(" use the overload that takes an explicit SourceLanguage" , since = " 0.7.1" )
1915 def erase (tpe : Type )(using Context ): ErasedTypeRef =
2016 erase(tpe, SourceLanguage .Scala3 )
@@ -27,44 +23,41 @@ private[tastyquery] object Erasure:
2723 finishErase(preErase(tpe, keepUnit))
2824 end erase
2925
30- /** First pass of erasure, where some special types are preserved as is.
26+ private [tastyquery] def eraseForSigName (tpe : Type , language : SourceLanguage , keepUnit : Boolean )(
27+ using Context
28+ ): ErasedTypeRef =
29+ given SourceLanguage = language
30+
31+ val patchedPreErased = preErase(tpe, keepUnit) match
32+ case ArrayTypeRef (ClassRef (cls), dimensions) if cls.isDerivedValueClass =>
33+ // Hack! dotc's `sigName` does *not* correspond to erasure in this case!
34+ val patchedBase =
35+ if cls.typeParams.isEmpty then preEraseMonoValueClass(cls)
36+ else preErasePolyValueClass(cls, cls.typeParams.map(_.localRef))
37+ patchedBase.underlying.multiArrayOf(dimensions)
38+ case typeRef =>
39+ typeRef
40+
41+ finishErase(patchedPreErased)
42+ end eraseForSigName
43+
44+ private final case class ErasedValueClass (valueClass : ClassSymbol , underlying : ErasedTypeRef )
45+
46+ private type PreErasedTypeRef = ErasedTypeRef | ErasedValueClass
47+
48+ /** First pass of erasure, where some special types are preserved as is,
49+ * and where value classes become `ErasedValueClass`es.
3150 *
3251 * In particular, `Any` is preserved as `Any`, instead of becoming
3352 * `java.lang.Object`.
3453 */
35- private def preErase (tpe : Type , keepUnit : Boolean )(using Context , SourceLanguage ): ErasedTypeRef =
36- def arrayOfBounds (bounds : TypeBounds ): ErasedTypeRef =
37- preErase(bounds.high, keepUnit = false ) match
38- case ClassRef (cls) if cls.isAny || cls.isAnyVal =>
39- ClassRef (defn.ObjectClass )
40- case typeRef =>
41- typeRef.arrayOf()
42-
43- def arrayOf (tpe : TypeOrWildcard ): ErasedTypeRef = tpe match
44- case tpe : AppliedType =>
45- tpe.tycon match
46- case TypeRef .OfClass (cls) =>
47- if cls.isArray then
48- val List (targ) = tpe.args: @ unchecked
49- arrayOf(targ).arrayOf()
50- else ClassRef (cls).arrayOf()
51- case _ =>
52- arrayOf(tpe.translucentSuperType)
53- case TypeRef .OfClass (cls) =>
54- if cls.isUnit then ClassRef (defn.ErasedBoxedUnitClass ).arrayOf()
55- else ClassRef (cls).arrayOf()
56- case tpe : TypeRef =>
57- tpe.optSymbol match
58- case Some (sym : TypeMemberSymbol ) if sym.isOpaqueTypeAlias =>
59- arrayOf(tpe.translucentSuperType)
60- case _ =>
61- tpe.bounds match
62- case bounds : AbstractTypeBounds => arrayOfBounds(bounds)
63- case TypeAlias (alias) => arrayOf(alias)
64- case tpe : TypeParamRef => arrayOfBounds(tpe.bounds)
65- case tpe : Type => preErase(tpe, keepUnit = false ).arrayOf()
66- case tpe : WildcardTypeArg => arrayOfBounds(tpe.bounds)
67- end arrayOf
54+ private def preErase (tpe : Type , keepUnit : Boolean )(using Context , SourceLanguage ): PreErasedTypeRef =
55+ def arrayOf (tpe : TypeOrWildcard ): ErasedTypeRef =
56+ if isGenericArrayElement(tpe) then ClassRef (defn.ObjectClass )
57+ else
58+ preErase(tpe.highIfWildcard, keepUnit = false ) match
59+ case base : ErasedTypeRef => base.arrayOf()
60+ case ErasedValueClass (valueClass, _) => ClassRef (valueClass).arrayOf()
6861
6962 tpe match
7063 case tpe : AppliedType =>
@@ -73,11 +66,13 @@ private[tastyquery] object Erasure:
7366 if cls.isArray then
7467 val List (targ) = tpe.args: @ unchecked
7568 arrayOf(targ)
69+ else if cls.isDerivedValueClass then preErasePolyValueClass(cls, tpe.args)
7670 else ClassRef (cls)
7771 case _ =>
7872 preErase(tpe.translucentSuperType, keepUnit)
7973 case TypeRef .OfClass (cls) =>
8074 if ! keepUnit && cls.isUnit then ClassRef (defn.ErasedBoxedUnitClass )
75+ else if cls.isDerivedValueClass then preEraseMonoValueClass(cls)
8176 else ClassRef (cls)
8277 case tpe : TypeRef =>
8378 preErase(tpe.translucentSuperType, keepUnit)
@@ -90,7 +85,10 @@ private[tastyquery] object Erasure:
9085 case Some (reduced) => preErase(reduced, keepUnit)
9186 case None => preErase(tpe.bound, keepUnit)
9287 case tpe : OrType =>
93- erasedLub(preErase(tpe.first, keepUnit = false ), preErase(tpe.second, keepUnit = false ))
88+ erasedLub(
89+ finishErase(preErase(tpe.first, keepUnit = false )),
90+ finishErase(preErase(tpe.second, keepUnit = false ))
91+ )
9492 case tpe : AndType =>
9593 summon[SourceLanguage ] match
9694 case SourceLanguage .Java =>
@@ -120,29 +118,157 @@ private[tastyquery] object Erasure:
120118 throw IllegalArgumentException (s " Unexpected type in erasure: $tpe" )
121119 end preErase
122120
123- private def finishErase (typeRef : ErasedTypeRef )(using Context ): ErasedTypeRef =
121+ private def finishErase (typeRef : PreErasedTypeRef )(using Context , SourceLanguage ): ErasedTypeRef =
124122 typeRef match
125- case ClassRef (cls) =>
126- if cls.isDerivedValueClass then finishEraseValueClass(cls)
127- else cls.erasure
128- case ArrayTypeRef (ClassRef (cls), dimensions) =>
129- ArrayTypeRef (cls.erasure, dimensions)
123+ case ClassRef (cls) => cls.erasure
124+ case ArrayTypeRef (ClassRef (cls), dimensions) => ArrayTypeRef (cls.erasure, dimensions)
125+ case ErasedValueClass (_, underlying) => finishErase(underlying)
130126 end finishErase
131127
132- private def finishEraseValueClass (cls : ClassSymbol )(using Context ): ErasedTypeRef =
128+ private def preEraseMonoValueClass (cls : ClassSymbol )(using Context , SourceLanguage ): ErasedValueClass =
129+ val ctor = cls.findNonOverloadedDecl(nme.Constructor )
130+
131+ val underlying = ctor.declaredType match
132+ case tpe : MethodType if tpe.paramNames.sizeIs == 1 =>
133+ tpe.paramTypes.head
134+ case _ =>
135+ throw InvalidProgramStructureException (s " Illegal value class constructor type ${ctor.declaredType.showBasic}" )
136+
137+ // The underlying of value classes are never value classes themselves (by language spec)
138+ val erasedUnderlying = preErase(underlying, keepUnit = false ).asInstanceOf [ErasedTypeRef ]
139+
140+ ErasedValueClass (cls, erasedUnderlying)
141+ end preEraseMonoValueClass
142+
143+ private def preErasePolyValueClass (cls : ClassSymbol , targs : List [TypeOrWildcard ])(
144+ using Context ,
145+ SourceLanguage
146+ ): ErasedValueClass =
133147 val ctor = cls.findNonOverloadedDecl(nme.Constructor )
134148
135149 def illegalConstructorType (): Nothing =
136150 throw InvalidProgramStructureException (s " Illegal value class constructor type ${ctor.declaredType.showBasic}" )
137151
138152 def ctorParamType (tpe : TypeOrMethodic ): Type = tpe match
139153 case tpe : MethodType if tpe.paramTypes.sizeIs == 1 => tpe.paramTypes.head
140- case tpe : MethodType => illegalConstructorType()
141- case tpe : PolyType => ctorParamType(tpe.resultType)
142- case tpe : Type => illegalConstructorType()
154+ case _ => illegalConstructorType()
155+
156+ val ctorPolyType = ctor.declaredType match
157+ case tpe : PolyType => tpe
158+ case _ => illegalConstructorType()
159+
160+ val genericUnderlying = ctorParamType(ctorPolyType.resultType)
161+ val specializedUnderlying = ctorParamType(ctorPolyType.instantiate(targs))
162+
163+ // The underlying of value classes are never value classes themselves (by language spec)
164+ val erasedGenericUnderlying = preErase(genericUnderlying, keepUnit = false ).asInstanceOf [ErasedTypeRef ]
165+ val erasedSpecializedUnderlying = preErase(specializedUnderlying, keepUnit = false ).asInstanceOf [ErasedTypeRef ]
143166
144- erase(ctorParamType(ctor.declaredType), ctor.sourceLanguage)
145- end finishEraseValueClass
167+ def isPrimitive (typeRef : ErasedTypeRef ): Boolean = typeRef match
168+ case ClassRef (cls) => cls.isPrimitiveValueClass
169+ case _ : ArrayTypeRef => false
170+
171+ /* Ideally, we would just use `erasedSpecializedUnderlying` as the erasure of `tp`.
172+ * However, there are two special cases for polymorphic value classes, which
173+ * historically come from Scala 2:
174+ *
175+ * - Given `class Foo[A](x: A) extends AnyVal`, `Foo[X]` should erase like
176+ * `X`, except if its a primitive in which case it erases to the boxed
177+ * version of this primitive.
178+ * - Given `class Bar[A](x: Array[A]) extends AnyVal`, `Bar[X]` will be
179+ * erased like `Array[A]` as seen from its definition site, no matter
180+ * the `X` (same if `A` is bounded).
181+ */
182+ val erasedValueClass =
183+ if isPrimitive(erasedSpecializedUnderlying) && ! isPrimitive(erasedGenericUnderlying) then
184+ ClassRef (erasedSpecializedUnderlying.asInstanceOf [ClassRef ].cls.boxedClass)
185+ else if genericUnderlying.baseType(defn.ArrayClass ).isDefined then erasedGenericUnderlying
186+ else erasedSpecializedUnderlying
187+
188+ ErasedValueClass (cls, erasedValueClass)
189+ end preErasePolyValueClass
190+
191+ /** Is `Array[tp]` a generic Array that needs to be erased to `Object`?
192+ * This is true if among the subtypes of `Array[tp]` there is either:
193+ * - both a reference array type and a primitive array type
194+ * (e.g. `Array[_ <: Int | String]`, `Array[_ <: Any]`)
195+ * - or two different primitive array types (e.g. `Array[_ <: Int | Double]`)
196+ * In both cases the erased lub of those array types on the JVM is `Object`.
197+ *
198+ * In addition, if `isScala2` is true, we mimic the Scala 2 erasure rules and
199+ * also return true for element types upper-bounded by a non-reference type
200+ * such as in `Array[_ <: Int]` or `Array[_ <: UniversalTrait]`.
201+ */
202+ private def isGenericArrayElement (tp : TypeOrWildcard )(using Context , SourceLanguage ): Boolean =
203+ /** A symbol that represents the sort of JVM array that values of type `tp` can be stored in:
204+ * - If we can always store such values in a reference array, return `j.l.Object`.
205+ * - If we can always store them in a specific primitive array, return the corresponding primitive class.
206+ * - Otherwise, return `None`.
207+ */
208+ def arrayUpperBound (tp : Type ): Option [ClassSymbol ] = tp.dealias match
209+ case TypeRef .OfClass (cls) =>
210+ def isScala2SpecialCase : Boolean =
211+ summon[SourceLanguage ] == SourceLanguage .Scala2
212+ && ! cls.isNull
213+ && ! cls.isSubClass(defn.ObjectClass )
214+
215+ // Only a few classes have both primitives and references as subclasses.
216+ if cls.isAny || cls.isAnyVal || cls.isMatchable || cls.isSingleton || isScala2SpecialCase then None
217+ else if cls.isPrimitiveValueClass then Some (cls)
218+ else
219+ // Derived value classes in arrays are always boxed, so they end up here as well
220+ Some (defn.ObjectClass )
221+
222+ case tp : TypeProxy =>
223+ arrayUpperBound(tp.translucentSuperType)
224+ case tp : AndType =>
225+ arrayUpperBound(tp.first).orElse(arrayUpperBound(tp.second))
226+ case tp : OrType =>
227+ val firstBound = arrayUpperBound(tp.first)
228+ val secondBound = arrayUpperBound(tp.first)
229+ if firstBound == secondBound then firstBound
230+ else None
231+ case _ : NothingType | _ : AnyKindType | _ : TypeLambda =>
232+ None
233+ case tp : CustomTransientGroundType =>
234+ throw IllegalArgumentException (s " Unexpected transient type: $tp" )
235+ end arrayUpperBound
236+
237+ /** Can one of the JVM Array type store all possible values of type `tp`? */
238+ def fitsInJVMArray (tp : Type ): Boolean = arrayUpperBound(tp).isDefined
239+
240+ tp match
241+ case tp : WildcardTypeArg =>
242+ ! fitsInJVMArray(tp.bounds.high)
243+
244+ case tp : Type =>
245+ tp.dealias match
246+ case tp : TypeRef =>
247+ tp.optSymbol match
248+ case Some (cls : ClassSymbol ) =>
249+ false
250+ case Some (sym : TypeMemberSymbol ) if sym.isOpaqueTypeAlias =>
251+ isGenericArrayElement(tp.translucentSuperType)
252+ case _ =>
253+ tp.bounds match
254+ case TypeAlias (alias) => isGenericArrayElement(alias)
255+ case AbstractTypeBounds (_, high) => ! fitsInJVMArray(high)
256+ case tp : TypeParamRef =>
257+ ! fitsInJVMArray(tp)
258+ case tp : MatchType =>
259+ val cases = tp.cases
260+ cases.nonEmpty && ! fitsInJVMArray(cases.map(_.result).reduce(OrType (_, _)))
261+ case tp : TypeProxy =>
262+ isGenericArrayElement(tp.translucentSuperType)
263+ case tp : AndType =>
264+ isGenericArrayElement(tp.first) && isGenericArrayElement(tp.second)
265+ case tp : OrType =>
266+ isGenericArrayElement(tp.first) || isGenericArrayElement(tp.second)
267+ case _ : NothingType | _ : AnyKindType | _ : TypeLambda =>
268+ false
269+ case tp : CustomTransientGroundType =>
270+ throw IllegalArgumentException (s " Unexpected transient type: $tp" )
271+ end isGenericArrayElement
146272
147273 /** The erased least upper bound of two erased types is computed as follows.
148274 *
@@ -224,7 +350,7 @@ private[tastyquery] object Erasure:
224350 * - Associativity and commutativity, because this method acts as the minimum
225351 * of the total order induced by `compareErasedGlb`.
226352 */
227- private def erasedGlb (tp1 : ErasedTypeRef , tp2 : ErasedTypeRef )(using Context ): ErasedTypeRef =
353+ private def erasedGlb (tp1 : PreErasedTypeRef , tp2 : PreErasedTypeRef )(using Context ): PreErasedTypeRef =
228354 if compareErasedGlb(tp1, tp2) <= 0 then tp1
229355 else tp2
230356
@@ -248,7 +374,7 @@ private[tastyquery] object Erasure:
248374 *
249375 * @see erasedGlb
250376 */
251- private def compareErasedGlb (tp1 : ErasedTypeRef , tp2 : ErasedTypeRef )(using Context ): Int =
377+ private def compareErasedGlb (tp1 : PreErasedTypeRef , tp2 : PreErasedTypeRef )(using Context ): Int =
252378 def compareClasses (cls1 : ClassSymbol , cls2 : ClassSymbol ): Int =
253379 if cls1.isSubClass(cls2) then - 1
254380 else if cls2.isSubClass(cls1) then 1
@@ -260,13 +386,11 @@ private[tastyquery] object Erasure:
260386 // fast path
261387 0
262388
263- case (ClassRef (cls1), _) if cls1.isDerivedValueClass =>
264- tp2 match
265- case ClassRef (cls2) if cls2.isDerivedValueClass =>
266- compareClasses(cls1, cls2)
267- case _ =>
268- - 1
269- case (_, ClassRef (cls2)) if cls2.isDerivedValueClass =>
389+ case (ErasedValueClass (cls1, _), ErasedValueClass (cls2, _)) =>
390+ compareClasses(cls1, cls2)
391+ case (ErasedValueClass (cls1, _), _) =>
392+ - 1
393+ case (_, ErasedValueClass (cls2, _)) =>
270394 1
271395
272396 case (tp1 : ArrayTypeRef , tp2 : ArrayTypeRef ) =>
0 commit comments