Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/config/Feature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ object Feature:

val dependent = experimental("dependent")
val erasedDefinitions = experimental("erasedDefinitions")
val strictEqualityPatternMatching = experimental("strictEqualityPatternMatching")
val symbolLiterals = deprecated("symbolLiterals")
val saferExceptions = experimental("saferExceptions")
val pureFunctions = experimental("pureFunctions")
Expand Down Expand Up @@ -58,6 +59,7 @@ object Feature:
(scala2macros, "Allow Scala 2 macros"),
(dependent, "Allow dependent method types"),
(erasedDefinitions, "Allow erased definitions"),
(strictEqualityPatternMatching, "relaxed CanEqual checks for ADT pattern matching"),
(symbolLiterals, "Allow symbol literals"),
(saferExceptions, "Enable safer exceptions"),
(pureFunctions, "Enable pure functions for capture checking"),
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1338,7 +1338,7 @@ trait Applications extends Compatibility {
case Apply(fn @ Select(left, _), right :: Nil) if fn.hasType =>
val op = fn.symbol
if (op == defn.Any_== || op == defn.Any_!=)
checkCanEqual(left.tpe.widen, right.tpe.widen, app.span)
checkCanEqual(left, right.tpe.widen, app.span)
case _ =>
}
app
Expand Down
18 changes: 13 additions & 5 deletions compiler/src/dotty/tools/dotc/typer/Implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ object Implicits:
def strictEquality(using Context): Boolean =
ctx.mode.is(Mode.StrictEquality) || Feature.enabled(nme.strictEquality)

def strictEqualityPatternMatching(using Context): Boolean =
Feature.enabled(Feature.strictEqualityPatternMatching)


/** A common base class of contextual implicits and of-type implicits which
* represents a set of references to implicit definitions.
Expand Down Expand Up @@ -1039,7 +1042,7 @@ trait Implicits:
* - if one of T, U is a subtype of the lifted version of the other,
* unless strict equality is set.
*/
def assumedCanEqual(ltp: Type, rtp: Type)(using Context) = {
def assumedCanEqual(leftTreeOption: Option[Tree], ltp: Type, rtp: Type)(using Context): Boolean = {
// Map all non-opaque abstract types to their upper bound.
// This is done to check whether such types might plausibly be comparable to each other.
val lift = new TypeMap {
Expand All @@ -1062,15 +1065,20 @@ trait Implicits:

ltp.isError
|| rtp.isError
|| !strictEquality && (ltp <:< lift(rtp) || rtp <:< lift(ltp))
|| locally:
if strictEquality then
strictEqualityPatternMatching && leftTreeOption.exists: leftTree =>
ltp <:< lift(rtp) && (leftTree.symbol.flags.isAllOf(Flags.EnumValue) || leftTree.symbol.flags.isAllOf(Flags.Module | Flags.Case))
else
(ltp <:< lift(rtp) || rtp <:< lift(ltp))
}

/** Check that equality tests between types `ltp` and `rtp` make sense */
def checkCanEqual(ltp: Type, rtp: Type, span: Span)(using Context): Unit =
if (!ctx.isAfterTyper && !assumedCanEqual(ltp, rtp)) {
def checkCanEqual(left: Tree, rtp: Type, span: Span)(using Context): Unit =
val ltp = left.tpe.widen
if !ctx.isAfterTyper && !assumedCanEqual(Some(left), ltp, rtp) then
val res = implicitArgTree(defn.CanEqualClass.typeRef.appliedTo(ltp, rtp), span)
implicits.println(i"CanEqual witness found for $ltp / $rtp: $res: ${res.tpe}")
}

object hasSkolem extends TreeAccumulator[Boolean]:
def apply(x: Boolean, tree: Tree)(using Context): Boolean =
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/ReTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ class ReTyper(nestingLevel: Int = 0) extends Typer(nestingLevel) with ReChecking

override def inferView(from: Tree, to: Type)(using Context): Implicits.SearchResult =
Implicits.NoMatchingImplicitsFailure
override def checkCanEqual(ltp: Type, rtp: Type, span: Span)(using Context): Unit = ()
override def checkCanEqual(left: Tree, rtp: Type, span: Span)(using Context): Unit = ()

override def widenEnumCase(tree: Tree, pt: Type)(using Context): Tree = tree

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
* one of `tp1`, `tp2` has a reflexive `CanEqual` instance.
*/
def validEqAnyArgs(tp1: Type, tp2: Type)(using Context) =
typer.assumedCanEqual(tp1, tp2)
typer.assumedCanEqual(None, tp1, tp2)
|| withMode(Mode.StrictEquality) {
!hasEq(tp1) && !hasEq(tp2)
}
Expand Down
45 changes: 45 additions & 0 deletions compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package dotty.tools.dotc.typer

import dotty.tools.DottyTest
import dotty.tools.dotc.core.Contexts.*

import org.junit.Test
import org.junit.Assert.fail

class SIP67Tests extends DottyTest:

private def checkNoErrors(source: String): Unit =
val ctx = checkCompile("typer", source)((_, _) => ())
if ctx.reporter.hasErrors then
fail("Unexpected compilation errors were reported")

@Test
def sip67test1: Unit =
checkNoErrors:
"""
import scala.language.strictEquality
import scala.language.experimental.strictEqualityPatternMatching
enum Foo:
case Bar

val _ =
(??? : Foo) match
case Foo.Bar =>
"""
@Test
def sip67test2: Unit =
checkNoErrors:
"""
import scala.language.strictEquality
import scala.language.experimental.strictEqualityPatternMatching

sealed trait Foo

object Foo:
case object Bar extends Foo

val _ =
(??? : Foo) match
case Foo.Bar =>
"""

7 changes: 7 additions & 0 deletions library/src/scala/runtime/stdLibPatches/language.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ object language:
@compileTimeOnly("`erasedDefinitions` can only be used at compile time in import statements")
object erasedDefinitions

/** Experimental support for relaxed CanEqual checks for ADT pattern matching
*
* @see [[https://github.com/scala/improvement-proposals/pull/97]]
*/
@compileTimeOnly("`strictEqualityPatternMatching` can only be used at compile time in import statements")
object strictEqualityPatternMatching

/** Experimental support for using indentation for arguments
*/
@compileTimeOnly("`fewerBraces` can only be used at compile time in import statements")
Expand Down
Loading