From 5778bea2f37a08eb5a8fdea4592164a92b50be2b Mon Sep 17 00:00:00 2001 From: Matthias Berndt Date: Sun, 24 Aug 2025 23:26:54 +0200 Subject: [PATCH 1/6] Implement SIP-67: strictEquality pattern matching (Github issue #22732) --- .../src/dotty/tools/dotc/typer/Typer.scala | 11 ++-- .../dotty/tools/dotc/typer/SIP67Tests.scala | 52 +++++++++++++++++++ 2 files changed, 58 insertions(+), 5 deletions(-) create mode 100644 compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index def6fac0556e..802a76f5eb31 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -5052,11 +5052,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer Linter.warnOnImplausiblePattern(tree, pt) - val cmp = - untpd.Apply( - untpd.Select(untpd.TypedSplice(tree), nme.EQ), - untpd.TypedSplice(dummyTreeOfType(pt))) - typedExpr(cmp, defn.BooleanType) + if ! (tree.tpe <:< pt && (tree.symbol.flags.isAllOf(Flags.EnumValue) || (tree.symbol.flags.isAllOf(Flags.Module | Flags.Case)))) then + val cmp = + untpd.Apply( + untpd.Select(untpd.TypedSplice(tree), nme.EQ), + untpd.TypedSplice(dummyTreeOfType(pt))) + typedExpr(cmp, defn.BooleanType) case _ => private def checkStatementPurity(tree: tpd.Tree)(original: untpd.Tree, exprOwner: Symbol, isUnitExpr: Boolean = false)(using Context): Unit = diff --git a/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala b/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala new file mode 100644 index 000000000000..07ad5c5fc423 --- /dev/null +++ b/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala @@ -0,0 +1,52 @@ +// filepath: /home/mberndt/scala3/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala +package dotty.tools.dotc.typer + +import dotty.tools.DottyTest +import dotty.tools.dotc.core.Contexts.* + +import org.junit.Test +import org.junit.Assert.fail + +class SIP67Tests extends DottyTest { + + + @Test + def sip67test1: Unit = { + val source = """ + import scala.language.strictEquality + enum Foo { + case Bar + } + + val _ = + (??? : Foo) match { + case Foo.Bar => + } + """ + val ctx = checkCompile("typer", source) { (_, ctx) => } + if ctx.reporter.hasErrors then + fail("Unexpected compilation errors were reported") + } + + @Test + def sip67test2: Unit = { + val source = """ + import scala.language.strictEquality + + sealed trait Foo + + object Foo { + case object Bar extends Foo + } + + val _ = + (??? : Foo) match { + case Foo.Bar => + } + """ + val ctx = checkCompile("typer", source) { (_, ctx) => } + if ctx.reporter.hasErrors then + fail("Unexpected compilation errors were reported") + } + +} \ No newline at end of file From 0a565d0caa6552646330e403ac386bbe568da908 Mon Sep 17 00:00:00 2001 From: Matthias Berndt Date: Tue, 26 Aug 2025 22:56:35 +0200 Subject: [PATCH 2/6] use braceless syntax in SIP67Tests --- .../dotty/tools/dotc/typer/SIP67Tests.scala | 29 +++++++------------ 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala b/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala index 07ad5c5fc423..0545b40f0a5a 100644 --- a/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala +++ b/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala @@ -1,4 +1,3 @@ -// filepath: /home/mberndt/scala3/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala package dotty.tools.dotc.typer import dotty.tools.DottyTest @@ -7,46 +6,38 @@ import dotty.tools.dotc.core.Contexts.* import org.junit.Test import org.junit.Assert.fail -class SIP67Tests extends DottyTest { - +class SIP67Tests extends DottyTest: @Test - def sip67test1: Unit = { + def sip67test1: Unit = val source = """ import scala.language.strictEquality - enum Foo { + enum Foo: case Bar - } val _ = - (??? : Foo) match { + (??? : Foo) match case Foo.Bar => - } """ - val ctx = checkCompile("typer", source) { (_, ctx) => } + val ctx = checkCompile("typer", source)((_, ctx) => ()) + if ctx.reporter.hasErrors then fail("Unexpected compilation errors were reported") - } @Test - def sip67test2: Unit = { + def sip67test2: Unit = val source = """ import scala.language.strictEquality sealed trait Foo - object Foo { + object Foo: case object Bar extends Foo - } val _ = - (??? : Foo) match { + (??? : Foo) match case Foo.Bar => - } """ - val ctx = checkCompile("typer", source) { (_, ctx) => } + val ctx = checkCompile("typer", source)((_, ctx) => ()) if ctx.reporter.hasErrors then fail("Unexpected compilation errors were reported") - } - -} \ No newline at end of file From d433ebed9c6812424c31c9d17919325e361b4126 Mon Sep 17 00:00:00 2001 From: Matthias Berndt Date: Sun, 31 Aug 2025 22:35:58 +0200 Subject: [PATCH 3/6] change SIP-67 implementation based on @odersky's review --- .../src/dotty/tools/dotc/typer/Applications.scala | 2 +- .../src/dotty/tools/dotc/typer/Implicits.scala | 15 ++++++++++----- compiler/src/dotty/tools/dotc/typer/ReTyper.scala | 2 +- .../src/dotty/tools/dotc/typer/Synthesizer.scala | 2 +- compiler/src/dotty/tools/dotc/typer/Typer.scala | 11 +++++------ 5 files changed, 18 insertions(+), 14 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 290e061772e4..322419e3f6cf 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -1338,7 +1338,7 @@ trait Applications extends Compatibility { case Apply(fn @ Select(left, _), right :: Nil) if fn.hasType => val op = fn.symbol if (op == defn.Any_== || op == defn.Any_!=) - checkCanEqual(left.tpe.widen, right.tpe.widen, app.span) + checkCanEqual(left, right.tpe.widen, app.span) case _ => } app diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index fbdf6ab80bbf..172f7d6c5920 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -1039,7 +1039,7 @@ trait Implicits: * - if one of T, U is a subtype of the lifted version of the other, * unless strict equality is set. */ - def assumedCanEqual(ltp: Type, rtp: Type)(using Context) = { + def assumedCanEqual(leftTreeOption: Option[Tree], ltp: Type, rtp: Type)(using Context): Boolean = { // Map all non-opaque abstract types to their upper bound. // This is done to check whether such types might plausibly be comparable to each other. val lift = new TypeMap { @@ -1062,15 +1062,20 @@ trait Implicits: ltp.isError || rtp.isError - || !strictEquality && (ltp <:< lift(rtp) || rtp <:< lift(ltp)) + || locally: + if strictEquality then + leftTreeOption.exists: leftTree => + ltp <:< lift(rtp) && (leftTree.symbol.flags.isAllOf(Flags.EnumValue) || (leftTree.symbol.flags.isAllOf(Flags.Module | Flags.Case))) + else + (ltp <:< lift(rtp) || rtp <:< lift(ltp)) } /** Check that equality tests between types `ltp` and `rtp` make sense */ - def checkCanEqual(ltp: Type, rtp: Type, span: Span)(using Context): Unit = - if (!ctx.isAfterTyper && !assumedCanEqual(ltp, rtp)) { + def checkCanEqual(left: Tree, rtp: Type, span: Span)(using Context): Unit = + val ltp = left.tpe.widen + if !ctx.isAfterTyper && !assumedCanEqual(Some(left), ltp, rtp) then val res = implicitArgTree(defn.CanEqualClass.typeRef.appliedTo(ltp, rtp), span) implicits.println(i"CanEqual witness found for $ltp / $rtp: $res: ${res.tpe}") - } object hasSkolem extends TreeAccumulator[Boolean]: def apply(x: Boolean, tree: Tree)(using Context): Boolean = diff --git a/compiler/src/dotty/tools/dotc/typer/ReTyper.scala b/compiler/src/dotty/tools/dotc/typer/ReTyper.scala index ed8919661860..8400639706d6 100644 --- a/compiler/src/dotty/tools/dotc/typer/ReTyper.scala +++ b/compiler/src/dotty/tools/dotc/typer/ReTyper.scala @@ -175,7 +175,7 @@ class ReTyper(nestingLevel: Int = 0) extends Typer(nestingLevel) with ReChecking override def inferView(from: Tree, to: Type)(using Context): Implicits.SearchResult = Implicits.NoMatchingImplicitsFailure - override def checkCanEqual(ltp: Type, rtp: Type, span: Span)(using Context): Unit = () + override def checkCanEqual(left: Tree, rtp: Type, span: Span)(using Context): Unit = () override def widenEnumCase(tree: Tree, pt: Type)(using Context): Tree = tree diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index 3b114de6a05c..57c179a77f74 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -183,7 +183,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): * one of `tp1`, `tp2` has a reflexive `CanEqual` instance. */ def validEqAnyArgs(tp1: Type, tp2: Type)(using Context) = - typer.assumedCanEqual(tp1, tp2) + typer.assumedCanEqual(None, tp1, tp2) || withMode(Mode.StrictEquality) { !hasEq(tp1) && !hasEq(tp2) } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 802a76f5eb31..def6fac0556e 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -5052,12 +5052,11 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer Linter.warnOnImplausiblePattern(tree, pt) - if ! (tree.tpe <:< pt && (tree.symbol.flags.isAllOf(Flags.EnumValue) || (tree.symbol.flags.isAllOf(Flags.Module | Flags.Case)))) then - val cmp = - untpd.Apply( - untpd.Select(untpd.TypedSplice(tree), nme.EQ), - untpd.TypedSplice(dummyTreeOfType(pt))) - typedExpr(cmp, defn.BooleanType) + val cmp = + untpd.Apply( + untpd.Select(untpd.TypedSplice(tree), nme.EQ), + untpd.TypedSplice(dummyTreeOfType(pt))) + typedExpr(cmp, defn.BooleanType) case _ => private def checkStatementPurity(tree: tpd.Tree)(original: untpd.Tree, exprOwner: Symbol, isUnitExpr: Boolean = false)(using Context): Unit = From 4f7a9a3a2f18a05ad1f5ba3dad0f297bfc29b949 Mon Sep 17 00:00:00 2001 From: Matthias Berndt Date: Sun, 31 Aug 2025 22:54:40 +0200 Subject: [PATCH 4/6] hide SIP-67 behaviour behind experimental feature flag --- compiler/src/dotty/tools/dotc/config/Feature.scala | 2 ++ compiler/src/dotty/tools/dotc/typer/Implicits.scala | 5 ++++- compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala | 2 ++ library/src/scala/runtime/stdLibPatches/language.scala | 7 +++++++ 4 files changed, 15 insertions(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/config/Feature.scala b/compiler/src/dotty/tools/dotc/config/Feature.scala index 70a77c9560b2..c5698178f628 100644 --- a/compiler/src/dotty/tools/dotc/config/Feature.scala +++ b/compiler/src/dotty/tools/dotc/config/Feature.scala @@ -28,6 +28,7 @@ object Feature: val dependent = experimental("dependent") val erasedDefinitions = experimental("erasedDefinitions") + val strictEqualityPatternMatching = experimental("strictEqualityPatternMatching") val symbolLiterals = deprecated("symbolLiterals") val saferExceptions = experimental("saferExceptions") val pureFunctions = experimental("pureFunctions") @@ -58,6 +59,7 @@ object Feature: (scala2macros, "Allow Scala 2 macros"), (dependent, "Allow dependent method types"), (erasedDefinitions, "Allow erased definitions"), + (strictEqualityPatternMatching, "relaxed CanEqual checks for ADT pattern matching"), (symbolLiterals, "Allow symbol literals"), (saferExceptions, "Enable safer exceptions"), (pureFunctions, "Enable pure functions for capture checking"), diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index 172f7d6c5920..be5bc5cd2557 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -84,6 +84,9 @@ object Implicits: def strictEquality(using Context): Boolean = ctx.mode.is(Mode.StrictEquality) || Feature.enabled(nme.strictEquality) + def strictEqualityPatternMatching(using Context): Boolean = + Feature.enabled(Feature.strictEqualityPatternMatching) + /** A common base class of contextual implicits and of-type implicits which * represents a set of references to implicit definitions. @@ -1064,7 +1067,7 @@ trait Implicits: || rtp.isError || locally: if strictEquality then - leftTreeOption.exists: leftTree => + strictEqualityPatternMatching && leftTreeOption.exists: leftTree => ltp <:< lift(rtp) && (leftTree.symbol.flags.isAllOf(Flags.EnumValue) || (leftTree.symbol.flags.isAllOf(Flags.Module | Flags.Case))) else (ltp <:< lift(rtp) || rtp <:< lift(ltp)) diff --git a/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala b/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala index 0545b40f0a5a..5b47718fb5cb 100644 --- a/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala +++ b/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala @@ -12,6 +12,7 @@ class SIP67Tests extends DottyTest: def sip67test1: Unit = val source = """ import scala.language.strictEquality + import scala.language.experimental.strictEqualityPatternMatching enum Foo: case Bar @@ -28,6 +29,7 @@ class SIP67Tests extends DottyTest: def sip67test2: Unit = val source = """ import scala.language.strictEquality + import scala.language.experimental.strictEqualityPatternMatching sealed trait Foo diff --git a/library/src/scala/runtime/stdLibPatches/language.scala b/library/src/scala/runtime/stdLibPatches/language.scala index 9d38ea4371ff..5d13ea17f2b7 100644 --- a/library/src/scala/runtime/stdLibPatches/language.scala +++ b/library/src/scala/runtime/stdLibPatches/language.scala @@ -50,6 +50,13 @@ object language: @compileTimeOnly("`erasedDefinitions` can only be used at compile time in import statements") object erasedDefinitions + /** Experimental support for relaxed CanEqual checks for ADT pattern matching + * + * @see [[https://github.com/scala/improvement-proposals/pull/97]] + */ + @compileTimeOnly("`strictEqualityPatternMatching` can only be used at compile time in import statements") + object strictEqualityPatternMatching + /** Experimental support for using indentation for arguments */ @compileTimeOnly("`fewerBraces` can only be used at compile time in import statements") From 56f11c2ed461f97ff2debfd40aa0fdc72a0c4300 Mon Sep 17 00:00:00 2001 From: Matthias Berndt Date: Mon, 1 Sep 2025 01:00:47 +0200 Subject: [PATCH 5/6] remove excess parens --- compiler/src/dotty/tools/dotc/typer/Implicits.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/typer/Implicits.scala b/compiler/src/dotty/tools/dotc/typer/Implicits.scala index be5bc5cd2557..4a6e92bb9849 100644 --- a/compiler/src/dotty/tools/dotc/typer/Implicits.scala +++ b/compiler/src/dotty/tools/dotc/typer/Implicits.scala @@ -1068,7 +1068,7 @@ trait Implicits: || locally: if strictEquality then strictEqualityPatternMatching && leftTreeOption.exists: leftTree => - ltp <:< lift(rtp) && (leftTree.symbol.flags.isAllOf(Flags.EnumValue) || (leftTree.symbol.flags.isAllOf(Flags.Module | Flags.Case))) + ltp <:< lift(rtp) && (leftTree.symbol.flags.isAllOf(Flags.EnumValue) || leftTree.symbol.flags.isAllOf(Flags.Module | Flags.Case)) else (ltp <:< lift(rtp) || rtp <:< lift(ltp)) } From 6570fb2b3f47fd42d47c41b78510c9d169deeda2 Mon Sep 17 00:00:00 2001 From: Matthias Berndt Date: Mon, 1 Sep 2025 17:31:50 +0200 Subject: [PATCH 6/6] minor refactoring of SIP67Tests --- .../dotty/tools/dotc/typer/SIP67Tests.scala | 56 +++++++++---------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala b/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala index 5b47718fb5cb..8c6173a886d0 100644 --- a/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala +++ b/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala @@ -8,38 +8,38 @@ import org.junit.Assert.fail class SIP67Tests extends DottyTest: - @Test - def sip67test1: Unit = - val source = """ - import scala.language.strictEquality - import scala.language.experimental.strictEqualityPatternMatching - enum Foo: - case Bar - - val _ = - (??? : Foo) match - case Foo.Bar => - """ - val ctx = checkCompile("typer", source)((_, ctx) => ()) - + private def checkNoErrors(source: String): Unit = + val ctx = checkCompile("typer", source)((_, _) => ()) if ctx.reporter.hasErrors then fail("Unexpected compilation errors were reported") - + + @Test + def sip67test1: Unit = + checkNoErrors: + """ + import scala.language.strictEquality + import scala.language.experimental.strictEqualityPatternMatching + enum Foo: + case Bar + + val _ = + (??? : Foo) match + case Foo.Bar => + """ @Test def sip67test2: Unit = - val source = """ - import scala.language.strictEquality - import scala.language.experimental.strictEqualityPatternMatching + checkNoErrors: + """ + import scala.language.strictEquality + import scala.language.experimental.strictEqualityPatternMatching - sealed trait Foo + sealed trait Foo - object Foo: - case object Bar extends Foo + object Foo: + case object Bar extends Foo - val _ = - (??? : Foo) match - case Foo.Bar => - """ - val ctx = checkCompile("typer", source)((_, ctx) => ()) - if ctx.reporter.hasErrors then - fail("Unexpected compilation errors were reported") + val _ = + (??? : Foo) match + case Foo.Bar => + """ + \ No newline at end of file