@@ -26,6 +26,7 @@ import cc.*
2626import Capabilities .Capability
2727import NameKinds .WildcardParamName
2828import MatchTypes .isConcrete
29+ import scala .util .boundary , boundary .break
2930
3031/** Provides methods to compare types.
3132 */
@@ -2090,6 +2091,45 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
20902091 else op2
20912092 end necessaryEither
20922093
2094+ /** Finds the necessary (the weakest) GADT constraint among a list of them.
2095+ * It returns the one being subsumed by all others if exists, and `None` otherwise.
2096+ *
2097+ * This is used when typechecking pattern alternatives, for instance:
2098+ *
2099+ * enum Expr[+T]:
2100+ * case I1(x: Int) extends Expr[Int]
2101+ * case I2(x: Int) extends Expr[Int]
2102+ * case B(x: Boolean) extends Expr[Boolean]
2103+ * import Expr.*
2104+ *
2105+ * The following function should compile:
2106+ *
2107+ * def foo[T](e: Expr[T]): T = e match
2108+ * case I1(_) | I2(_) => 42
2109+ *
2110+ * since `T >: Int` is subsumed by both alternatives in the first match clause.
2111+ *
2112+ * However, the following should not:
2113+ *
2114+ * def foo[T](e: Expr[T]): T = e match
2115+ * case I1(_) | B(_) => 42
2116+ *
2117+ * since the `I1(_)` case gives the constraint `T >: Int` while `B(_)` gives `T >: Boolean`.
2118+ * Neither of the constraints is subsumed by the other.
2119+ */
2120+ def necessaryGadtConstraint (constrs : List [GadtConstraint ], preGadt : GadtConstraint )(using Context ): Option [GadtConstraint ] = boundary :
2121+ constrs match
2122+ case Nil => break(None )
2123+ case c0 :: constrs =>
2124+ var weakest = c0
2125+ for c <- constrs do
2126+ if subsumes(weakest.constraint, c.constraint, preGadt.constraint) then
2127+ weakest = c
2128+ else if ! subsumes(c.constraint, weakest.constraint, preGadt.constraint) then
2129+ // this two constraints are disjoint
2130+ break(None )
2131+ break(Some (weakest))
2132+
20932133 inline def rollbackConstraintsUnless (inline op : Boolean ): Boolean =
20942134 val saved = constraint
20952135 var result = false
@@ -3449,6 +3489,9 @@ object TypeComparer {
34493489 def constrainPatternType (pat : Type , scrut : Type , forceInvariantRefinement : Boolean = false )(using Context ): Boolean =
34503490 comparing(_.constrainPatternType(pat, scrut, forceInvariantRefinement))
34513491
3492+ def necessaryGadtConstraint (constrs : List [GadtConstraint ], preGadt : GadtConstraint )(using Context ): Option [GadtConstraint ] =
3493+ comparing(_.necessaryGadtConstraint(constrs, preGadt))
3494+
34523495 def explained [T ](op : ExplainingTypeComparer => T , header : String = " Subtype trace:" , short : Boolean = false )(using Context ): String =
34533496 comparing(_.explained(op, header, short))
34543497
0 commit comments