@@ -60,6 +60,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
60
60
/** Indicates whether the subtype check used GADT bounds */
61
61
private var GADTused : Boolean = false
62
62
63
+ protected var canWidenAbstract : Boolean = true
64
+
63
65
private var myInstance : TypeComparer = this
64
66
def currentInstance : TypeComparer = myInstance
65
67
@@ -757,9 +759,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
757
759
758
760
def tryBaseType (cls2 : Symbol ) = {
759
761
val base = nonExprBaseType(tp1, cls2)
760
- if (base.exists && (base `ne` tp1))
761
- isSubType(base, tp2, if (tp1.isRef(cls2)) approx else approx.addLow) ||
762
- base.isInstanceOf [OrType ] && fourthTry
762
+ if base.exists && (base ne tp1)
763
+ && (! caseLambda.exists || canWidenAbstract || tp1.widen.underlyingClassRef(refinementOK = true ).exists)
764
+ then
765
+ isSubType(base, tp2, if (tp1.isRef(cls2)) approx else approx.addLow)
766
+ || base.isInstanceOf [OrType ] && fourthTry
763
767
// if base is a disjunction, this might have come from a tp1 type that
764
768
// expands to a match type. In this case, we should try to reduce the type
765
769
// and compare the redux. This is done in fourthTry
@@ -776,7 +780,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
776
780
|| narrowGADTBounds(tp1, tp2, approx, isUpper = true ))
777
781
&& (tp2.isAny || GADTusage (tp1.symbol))
778
782
779
- isSubType(hi1, tp2, approx.addLow) || compareGADT || tryLiftedToThis1
783
+ (! caseLambda.exists || canWidenAbstract) && isSubType(hi1, tp2, approx.addLow)
784
+ || compareGADT
785
+ || tryLiftedToThis1
780
786
case _ =>
781
787
// `Mode.RelaxedOverriding` is only enabled when checking Java overriding
782
788
// in explicit nulls, and `Null` becomes a bottom type, which allows
@@ -2849,7 +2855,16 @@ object TypeComparer {
2849
2855
comparing(_.tracked(op))
2850
2856
}
2851
2857
2858
+ object TrackingTypeComparer :
2859
+ enum MatchResult :
2860
+ case Reduced (tp : Type )
2861
+ case Disjoint
2862
+ case Stuck
2863
+ case NoInstance (param : Name , bounds : TypeBounds )
2864
+
2852
2865
class TrackingTypeComparer (initctx : Context ) extends TypeComparer (initctx) {
2866
+ import TrackingTypeComparer .*
2867
+
2853
2868
init(initctx)
2854
2869
2855
2870
override def trackingTypeComparer = this
@@ -2887,15 +2902,25 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
2887
2902
}
2888
2903
2889
2904
def matchCases (scrut : Type , cases : List [Type ])(using Context ): Type = {
2890
- def paramInstances = new TypeAccumulator [Array [Type ]] {
2891
- def apply (inst : Array [Type ], t : Type ) = t match {
2892
- case t @ TypeParamRef (b, n) if b `eq` caseLambda =>
2893
- inst(n) = approximation(t, fromBelow = variance >= 0 ).simplified
2894
- inst
2905
+
2906
+ def paramInstances (canApprox : Boolean ) = new TypeAccumulator [Array [Type ]]:
2907
+ def apply (insts : Array [Type ], t : Type ) = t match
2908
+ case param @ TypeParamRef (b, n) if b eq caseLambda =>
2909
+ insts(n) = {
2910
+ if canApprox then
2911
+ approximation(param, fromBelow = variance >= 0 )
2912
+ else constraint.entry(param) match
2913
+ case entry : TypeBounds =>
2914
+ val lo = fullLowerBound(param)
2915
+ val hi = fullUpperBound(param)
2916
+ if isSubType(hi, lo) then lo else TypeBounds (lo, hi)
2917
+ case inst =>
2918
+ assert(inst.exists, i " param = $param\n constraint = $constraint" )
2919
+ inst
2920
+ }.simplified
2921
+ insts
2895
2922
case _ =>
2896
- foldOver(inst, t)
2897
- }
2898
- }
2923
+ foldOver(insts, t)
2899
2924
2900
2925
def instantiateParams (inst : Array [Type ]) = new TypeMap {
2901
2926
def apply (t : Type ) = t match {
@@ -2911,7 +2936,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
2911
2936
* None if the match fails and we should consider the following cases
2912
2937
* because scrutinee and pattern do not overlap
2913
2938
*/
2914
- def matchCase (cas : Type ): Option [ Type ] = trace(i " match case $cas vs $scrut" , matchTypes) {
2939
+ def matchCase (cas : Type ): MatchResult = trace(i " match case $cas vs $scrut" , matchTypes) {
2915
2940
val cas1 = cas match {
2916
2941
case cas : HKTypeLambda =>
2917
2942
caseLambda = constrained(cas)
@@ -2922,34 +2947,48 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
2922
2947
2923
2948
val defn .MatchCase (pat, body) = cas1 : @ unchecked
2924
2949
2925
- if (isSubType(scrut, pat))
2926
- // `scrut` is a subtype of `pat`: *It's a Match!*
2927
- Some {
2928
- caseLambda match {
2929
- case caseLambda : HKTypeLambda =>
2930
- val instances = paramInstances(new Array (caseLambda.paramNames.length), pat)
2931
- instantiateParams(instances)(body).simplified
2932
- case _ =>
2933
- body
2934
- }
2935
- }
2950
+ def matches (canWidenAbstract : Boolean ): Boolean =
2951
+ val saved = this .canWidenAbstract
2952
+ this .canWidenAbstract = canWidenAbstract
2953
+ try necessarySubType(scrut, pat)
2954
+ finally this .canWidenAbstract = saved
2955
+
2956
+ def redux (canApprox : Boolean ): MatchResult =
2957
+ caseLambda match
2958
+ case caseLambda : HKTypeLambda =>
2959
+ val instances = paramInstances(canApprox)(new Array (caseLambda.paramNames.length), pat)
2960
+ instances.indices.find(instances(_).isInstanceOf [TypeBounds ]) match
2961
+ case Some (i) if ! canApprox =>
2962
+ MatchResult .NoInstance (caseLambda.paramNames(i), instances(i).bounds)
2963
+ case _ =>
2964
+ MatchResult .Reduced (instantiateParams(instances)(body).simplified)
2965
+ case _ =>
2966
+ MatchResult .Reduced (body)
2967
+
2968
+ if caseLambda.exists && matches(canWidenAbstract = false ) then
2969
+ redux(canApprox = true )
2970
+ else if matches(canWidenAbstract = true ) then
2971
+ redux(canApprox = false )
2936
2972
else if (provablyDisjoint(scrut, pat))
2937
2973
// We found a proof that `scrut` and `pat` are incompatible.
2938
2974
// The search continues.
2939
- None
2975
+ MatchResult . Disjoint
2940
2976
else
2941
- Some ( NoType )
2977
+ MatchResult . Stuck
2942
2978
}
2943
2979
2944
2980
def recur (remaining : List [Type ]): Type = remaining match
2945
2981
case cas :: remaining1 =>
2946
2982
matchCase(cas) match
2947
- case None =>
2983
+ case MatchResult . Disjoint =>
2948
2984
recur(remaining1)
2949
- case Some ( NoType ) =>
2985
+ case MatchResult . Stuck =>
2950
2986
MatchTypeTrace .stuck(scrut, cas, remaining1)
2951
2987
NoType
2952
- case Some (tp) =>
2988
+ case MatchResult .NoInstance (pname, bounds) =>
2989
+ MatchTypeTrace .noInstance(scrut, cas, pname, bounds)
2990
+ NoType
2991
+ case MatchResult .Reduced (tp) =>
2953
2992
tp
2954
2993
case Nil =>
2955
2994
val casesText = MatchTypeTrace .noMatchesText(scrut, cases)
0 commit comments