diff --git a/compiler/src/dotty/tools/dotc/config/Feature.scala b/compiler/src/dotty/tools/dotc/config/Feature.scala index 70a77c9560b2..c5698178f628 100644 --- a/compiler/src/dotty/tools/dotc/config/Feature.scala +++ b/compiler/src/dotty/tools/dotc/config/Feature.scala @@ -28,6 +28,7 @@ object Feature: val dependent = experimental("dependent") val erasedDefinitions = experimental("erasedDefinitions") + val strictEqualityPatternMatching = experimental("strictEqualityPatternMatching") val symbolLiterals = deprecated("symbolLiterals") val saferExceptions = experimental("saferExceptions") val pureFunctions = experimental("pureFunctions") @@ -58,6 +59,7 @@ object Feature: (scala2macros, "Allow Scala 2 macros"), (dependent, "Allow dependent method types"), (erasedDefinitions, "Allow erased definitions"), + (strictEqualityPatternMatching, "relaxed CanEqual checks for ADT pattern matching"), (symbolLiterals, "Allow symbol literals"), (saferExceptions, "Enable safer exceptions"), (pureFunctions, "Enable pure functions for capture checking"), diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 290e061772e4..322419e3f6cf 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -1338,7 +1338,7 @@ trait Applications extends Compatibility { case Apply(fn @ Select(left, _), right :: Nil) if fn.hasType => val op = fn.symbol if (op == defn.Any_== || op == defn.Any_!=) - checkCanEqual(left.tpe.widen, right.tpe.widen, app.span) + checkCanEqual(left, right.tpe.widen, app.span) case _ => } app diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index fbdf6ab80bbf..4a6e92bb9849 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -84,6 +84,9 @@ object Implicits: def strictEquality(using Context): Boolean = ctx.mode.is(Mode.StrictEquality) || Feature.enabled(nme.strictEquality) + def strictEqualityPatternMatching(using Context): Boolean = + Feature.enabled(Feature.strictEqualityPatternMatching) + /** A common base class of contextual implicits and of-type implicits which * represents a set of references to implicit definitions. @@ -1039,7 +1042,7 @@ trait Implicits: * - if one of T, U is a subtype of the lifted version of the other, * unless strict equality is set. */ - def assumedCanEqual(ltp: Type, rtp: Type)(using Context) = { + def assumedCanEqual(leftTreeOption: Option[Tree], ltp: Type, rtp: Type)(using Context): Boolean = { // Map all non-opaque abstract types to their upper bound. // This is done to check whether such types might plausibly be comparable to each other. val lift = new TypeMap { @@ -1062,15 +1065,20 @@ trait Implicits: ltp.isError || rtp.isError - || !strictEquality && (ltp <:< lift(rtp) || rtp <:< lift(ltp)) + || locally: + if strictEquality then + strictEqualityPatternMatching && leftTreeOption.exists: leftTree => + ltp <:< lift(rtp) && (leftTree.symbol.flags.isAllOf(Flags.EnumValue) || leftTree.symbol.flags.isAllOf(Flags.Module | Flags.Case)) + else + (ltp <:< lift(rtp) || rtp <:< lift(ltp)) } /** Check that equality tests between types `ltp` and `rtp` make sense */ - def checkCanEqual(ltp: Type, rtp: Type, span: Span)(using Context): Unit = - if (!ctx.isAfterTyper && !assumedCanEqual(ltp, rtp)) { + def checkCanEqual(left: Tree, rtp: Type, span: Span)(using Context): Unit = + val ltp = left.tpe.widen + if !ctx.isAfterTyper && !assumedCanEqual(Some(left), ltp, rtp) then val res = implicitArgTree(defn.CanEqualClass.typeRef.appliedTo(ltp, rtp), span) implicits.println(i"CanEqual witness found for $ltp / $rtp: $res: ${res.tpe}") - } object hasSkolem extends TreeAccumulator[Boolean]: def apply(x: Boolean, tree: Tree)(using Context): Boolean = diff --git a/compiler/src/dotty/tools/dotc/typer/ReTyper.scala b/compiler/src/dotty/tools/dotc/typer/ReTyper.scala index ed8919661860..8400639706d6 100644 --- a/compiler/src/dotty/tools/dotc/typer/ReTyper.scala +++ b/compiler/src/dotty/tools/dotc/typer/ReTyper.scala @@ -175,7 +175,7 @@ class ReTyper(nestingLevel: Int = 0) extends Typer(nestingLevel) with ReChecking override def inferView(from: Tree, to: Type)(using Context): Implicits.SearchResult = Implicits.NoMatchingImplicitsFailure - override def checkCanEqual(ltp: Type, rtp: Type, span: Span)(using Context): Unit = () + override def checkCanEqual(left: Tree, rtp: Type, span: Span)(using Context): Unit = () override def widenEnumCase(tree: Tree, pt: Type)(using Context): Tree = tree diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index 3b114de6a05c..57c179a77f74 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -183,7 +183,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): * one of `tp1`, `tp2` has a reflexive `CanEqual` instance. */ def validEqAnyArgs(tp1: Type, tp2: Type)(using Context) = - typer.assumedCanEqual(tp1, tp2) + typer.assumedCanEqual(None, tp1, tp2) || withMode(Mode.StrictEquality) { !hasEq(tp1) && !hasEq(tp2) } diff --git a/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala b/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala new file mode 100644 index 000000000000..8c6173a886d0 --- /dev/null +++ b/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala @@ -0,0 +1,45 @@ +package dotty.tools.dotc.typer + +import dotty.tools.DottyTest +import dotty.tools.dotc.core.Contexts.* + +import org.junit.Test +import org.junit.Assert.fail + +class SIP67Tests extends DottyTest: + + private def checkNoErrors(source: String): Unit = + val ctx = checkCompile("typer", source)((_, _) => ()) + if ctx.reporter.hasErrors then + fail("Unexpected compilation errors were reported") + + @Test + def sip67test1: Unit = + checkNoErrors: + """ + import scala.language.strictEquality + import scala.language.experimental.strictEqualityPatternMatching + enum Foo: + case Bar + + val _ = + (??? : Foo) match + case Foo.Bar => + """ + @Test + def sip67test2: Unit = + checkNoErrors: + """ + import scala.language.strictEquality + import scala.language.experimental.strictEqualityPatternMatching + + sealed trait Foo + + object Foo: + case object Bar extends Foo + + val _ = + (??? : Foo) match + case Foo.Bar => + """ + \ No newline at end of file diff --git a/library/src/scala/runtime/stdLibPatches/language.scala b/library/src/scala/runtime/stdLibPatches/language.scala index 9d38ea4371ff..5d13ea17f2b7 100644 --- a/library/src/scala/runtime/stdLibPatches/language.scala +++ b/library/src/scala/runtime/stdLibPatches/language.scala @@ -50,6 +50,13 @@ object language: @compileTimeOnly("`erasedDefinitions` can only be used at compile time in import statements") object erasedDefinitions + /** Experimental support for relaxed CanEqual checks for ADT pattern matching + * + * @see [[https://github.com/scala/improvement-proposals/pull/97]] + */ + @compileTimeOnly("`strictEqualityPatternMatching` can only be used at compile time in import statements") + object strictEqualityPatternMatching + /** Experimental support for using indentation for arguments */ @compileTimeOnly("`fewerBraces` can only be used at compile time in import statements")