Skip to content

Commit 5bf8e4d

Browse files
committed
Implement SIP-67: strictEquality pattern matching (Github issue #22732)
1 parent 712d5bc commit 5bf8e4d

File tree

2 files changed

+58
-5
lines changed

2 files changed

+58
-5
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5052,11 +5052,12 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
50525052

50535053
Linter.warnOnImplausiblePattern(tree, pt)
50545054

5055-
val cmp =
5056-
untpd.Apply(
5057-
untpd.Select(untpd.TypedSplice(tree), nme.EQ),
5058-
untpd.TypedSplice(dummyTreeOfType(pt)))
5059-
typedExpr(cmp, defn.BooleanType)
5055+
if ! (tree.tpe <:< pt && (tree.symbol.flags.isAllOf(Flags.EnumValue) || (tree.symbol.flags.isAllOf(Flags.Module | Flags.Case)))) then
5056+
val cmp =
5057+
untpd.Apply(
5058+
untpd.Select(untpd.TypedSplice(tree), nme.EQ),
5059+
untpd.TypedSplice(dummyTreeOfType(pt)))
5060+
typedExpr(cmp, defn.BooleanType)
50605061
case _ =>
50615062

50625063
private def checkStatementPurity(tree: tpd.Tree)(original: untpd.Tree, exprOwner: Symbol, isUnitExpr: Boolean = false)(using Context): Unit =
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// filepath: /home/mberndt/scala3/compiler/test/dotty/tools/dotc/typer/SIP67Tests.scala
2+
package dotty.tools.dotc.typer
3+
4+
import dotty.tools.DottyTest
5+
import dotty.tools.dotc.core.Contexts.*
6+
7+
import org.junit.Test
8+
import org.junit.Assert.fail
9+
10+
class SIP67Tests extends DottyTest {
11+
12+
13+
@Test
14+
def sip67test1: Unit = {
15+
val source = """
16+
import scala.language.strictEquality
17+
enum Foo {
18+
case Bar
19+
}
20+
21+
val _ =
22+
(??? : Foo) match {
23+
case Foo.Bar =>
24+
}
25+
"""
26+
val ctx = checkCompile("typer", source) { (_, ctx) => }
27+
if ctx.reporter.hasErrors then
28+
fail("Unexpected compilation errors were reported")
29+
}
30+
31+
@Test
32+
def sip67test2: Unit = {
33+
val source = """
34+
import scala.language.strictEquality
35+
36+
sealed trait Foo
37+
38+
object Foo {
39+
case object Bar extends Foo
40+
}
41+
42+
val _ =
43+
(??? : Foo) match {
44+
case Foo.Bar =>
45+
}
46+
"""
47+
val ctx = checkCompile("typer", source) { (_, ctx) => }
48+
if ctx.reporter.hasErrors then
49+
fail("Unexpected compilation errors were reported")
50+
}
51+
52+
}

0 commit comments

Comments
 (0)