Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/config/Feature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ object Feature:

val dependent = experimental("dependent")
val erasedDefinitions = experimental("erasedDefinitions")
val strictEqualityPatternMatching = experimental("strictEqualityPatternMatching")
val symbolLiterals = deprecated("symbolLiterals")
val saferExceptions = experimental("saferExceptions")
val pureFunctions = experimental("pureFunctions")
Expand Down Expand Up @@ -58,6 +59,7 @@ object Feature:
(scala2macros, "Allow Scala 2 macros"),
(dependent, "Allow dependent method types"),
(erasedDefinitions, "Allow erased definitions"),
(strictEqualityPatternMatching, "relaxed CanEqual checks for ADT pattern matching"),
(symbolLiterals, "Allow symbol literals"),
(saferExceptions, "Enable safer exceptions"),
(pureFunctions, "Enable pure functions for capture checking"),
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1338,7 +1338,7 @@ trait Applications extends Compatibility {
case Apply(fn @ Select(left, _), right :: Nil) if fn.hasType =>
val op = fn.symbol
if (op == defn.Any_== || op == defn.Any_!=)
checkCanEqual(left.tpe.widen, right.tpe.widen, app.span)
checkCanEqual(left, right.tpe.widen, app.span)
case _ =>
}
app
Expand Down
18 changes: 13 additions & 5 deletions compiler/src/dotty/tools/dotc/typer/Implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ object Implicits:
def strictEquality(using Context): Boolean =
ctx.mode.is(Mode.StrictEquality) || Feature.enabled(nme.strictEquality)

def strictEqualityPatternMatching(using Context): Boolean =
Feature.enabled(Feature.strictEqualityPatternMatching)


/** A common base class of contextual implicits and of-type implicits which
* represents a set of references to implicit definitions.
Expand Down Expand Up @@ -1039,7 +1042,7 @@ trait Implicits:
* - if one of T, U is a subtype of the lifted version of the other,
* unless strict equality is set.
*/
def assumedCanEqual(ltp: Type, rtp: Type)(using Context) = {
def assumedCanEqual(leftTreeOption: Option[Tree], ltp: Type, rtp: Type)(using Context): Boolean = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We usually use EmptyTree for missing trees instead of an option type. Also, leftTree should be the last argument. Suggestion:

def assumedCanEqual(ltp: Type, rtp: Type, leftTree: Tree = EmptyTree)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks 👍🏻

// Map all non-opaque abstract types to their upper bound.
// This is done to check whether such types might plausibly be comparable to each other.
val lift = new TypeMap {
Expand All @@ -1062,15 +1065,20 @@ trait Implicits:

ltp.isError
|| rtp.isError
|| !strictEquality && (ltp <:< lift(rtp) || rtp <:< lift(ltp))
|| locally:
if strictEquality then
strictEqualityPatternMatching && leftTreeOption.exists: leftTree =>
ltp <:< lift(rtp) && (leftTree.symbol.flags.isAllOf(Flags.EnumValue) || leftTree.symbol.flags.isAllOf(Flags.Module | Flags.Case))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.flags is redundant, should be omitted. Suggestion:

          if strictEqualityPatternMatching 
            && (leftTree.symbol.isAllOf(Flags.EnumValue) || leftTree.symbol.isAllOf(Flags.Module | Flags.Case))
          then ltp <:< lift(rtp)
          else false

More pedestrian but much clearer and more efficient.

Copy link
Contributor Author

@mberndt123 mberndt123 Sep 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I've removed the .flags. Re replacing && with if: that would lead to nested if expressions, and I think a simple && is easier to read than that. I'd also like to understand why if would be more efficient given that && implements short-circuiting.

I'd prefer leaving it like this, but will change it if you insist.

else
(ltp <:< lift(rtp) || rtp <:< lift(ltp))
}

/** Check that equality tests between types `ltp` and `rtp` make sense */
def checkCanEqual(ltp: Type, rtp: Type, span: Span)(using Context): Unit =
if (!ctx.isAfterTyper && !assumedCanEqual(ltp, rtp)) {
def checkCanEqual(left: Tree, rtp: Type, span: Span)(using Context): Unit =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ltp is not part of the API anymore, so the doc comment needs to be updated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I've updated the comment

val ltp = left.tpe.widen
if !ctx.isAfterTyper && !assumedCanEqual(Some(left), ltp, rtp) then
val res = implicitArgTree(defn.CanEqualClass.typeRef.appliedTo(ltp, rtp), span)
implicits.println(i"CanEqual witness found for $ltp / $rtp: $res: ${res.tpe}")
}

object hasSkolem extends TreeAccumulator[Boolean]:
def apply(x: Boolean, tree: Tree)(using Context): Boolean =
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/ReTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ class ReTyper(nestingLevel: Int = 0) extends Typer(nestingLevel) with ReChecking

override def inferView(from: Tree, to: Type)(using Context): Implicits.SearchResult =
Implicits.NoMatchingImplicitsFailure
override def checkCanEqual(ltp: Type, rtp: Type, span: Span)(using Context): Unit = ()
override def checkCanEqual(left: Tree, rtp: Type, span: Span)(using Context): Unit = ()

override def widenEnumCase(tree: Tree, pt: Type)(using Context): Tree = tree

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Synthesizer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
* one of `tp1`, `tp2` has a reflexive `CanEqual` instance.
*/
def validEqAnyArgs(tp1: Type, tp2: Type)(using Context) =
typer.assumedCanEqual(tp1, tp2)
typer.assumedCanEqual(None, tp1, tp2)
|| withMode(Mode.StrictEquality) {
!hasEq(tp1) && !hasEq(tp2)
}
Expand Down
44 changes: 44 additions & 0 deletions compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package dotty.tools.dotc.typer

import dotty.tools.DottyTest
import dotty.tools.dotc.core.Contexts.*

import org.junit.Test
import org.junit.Assert.fail

class SIP67Tests extends DottyTest:

private def checkNoErrors(source: String): Unit =
val ctx = checkCompile("typer", source)((_, _) => ())
if ctx.reporter.hasErrors then
fail("Unexpected compilation errors were reported")

@Test
def sip67test1: Unit =
checkNoErrors:
"""
import scala.language.strictEquality
import scala.language.experimental.strictEqualityPatternMatching
enum Foo:
case Bar

val _ =
(??? : Foo) match
case Foo.Bar =>
"""
@Test
def sip67test2: Unit =
checkNoErrors:
"""
import scala.language.strictEquality
import scala.language.experimental.strictEqualityPatternMatching

sealed trait Foo

object Foo:
case object Bar extends Foo

val _ =
(??? : Foo) match
case Foo.Bar =>
"""
7 changes: 7 additions & 0 deletions library/src/scala/runtime/stdLibPatches/language.scala
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add this object in scala.language.experimental too (Not just in the stdLibPatches).

Copy link
Contributor Author

@mberndt123 mberndt123 Sep 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out Hamza, I've now added it.

Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ object language:
@compileTimeOnly("`erasedDefinitions` can only be used at compile time in import statements")
object erasedDefinitions

/** Experimental support for relaxed CanEqual checks for ADT pattern matching
*
* @see [[https://github.com/scala/improvement-proposals/pull/97]]
*/
@compileTimeOnly("`strictEqualityPatternMatching` can only be used at compile time in import statements")
object strictEqualityPatternMatching

/** Experimental support for using indentation for arguments
*/
@compileTimeOnly("`fewerBraces` can only be used at compile time in import statements")
Expand Down
Loading