Skip to content

Commit 855adf8

Browse files
committed
Fix #3248: support product-seq patterns
1 parent 0558618 commit 855adf8

File tree

3 files changed

+115
-55
lines changed

3 files changed

+115
-55
lines changed

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import collection.mutable
88
import Symbols._, Contexts._, Types._, StdNames._, NameOps._
99
import ast.Trees._
1010
import util.Spans._
11-
import typer.Applications.{isProductMatch, isGetMatch, productSelectors}
11+
import typer.Applications.{isProductMatch, isGetMatch, isProductSeqMatch, productSelectors, productArity}
1212
import SymUtils._
1313
import Flags._, Constants._
1414
import Decorators._
@@ -262,6 +262,8 @@ object PatternMatcher {
262262

263263
/** Plan for matching the sequence in `getResult` against sequence elements
264264
* and a possible last varargs argument `args`.
265+
*
266+
* `getResult` could also be a product, where the last element is a sequence of elements.
265267
*/
266268
def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan = args.lastOption match {
267269
case Some(VarArgPattern(arg)) =>
@@ -286,6 +288,22 @@ object PatternMatcher {
286288
matchElemsPlan(getResult, args, exact = true, onSuccess)
287289
}
288290

291+
/** Plan for matching the sequence in `getResult` against sequence elements
292+
* and a possible last varargs argument `args`.
293+
*
294+
* `getResult` is a product, where the last element is a sequence of elements.
295+
*/
296+
def unapplyProductSeqPlan(getResult: Symbol, args: List[Tree], arity: Int): Plan = {
297+
assert(arity <= args.size + 1)
298+
val selectors = productSelectors(getResult.info).map(ref(getResult).select(_))
299+
300+
val matchSeq =
301+
letAbstract(selectors.last) { seqResult =>
302+
unapplySeqPlan(seqResult, args.drop(arity - 1))
303+
}
304+
matchArgsPlan(selectors.take(arity - 1), args.take(arity - 1), matchSeq)
305+
}
306+
289307
/** Plan for matching the result of an unapply against argument patterns `args` */
290308
def unapplyPlan(unapp: Tree, args: List[Tree]): Plan = {
291309
def caseClass = unapp.symbol.owner.linkedClass
@@ -306,18 +324,34 @@ object PatternMatcher {
306324
.map(ref(unappResult).select(_))
307325
matchArgsPlan(selectors, args, onSuccess)
308326
}
327+
else if (isProductSeqMatch(unapp.tpe.widen, args.length, unapp.sourcePos) && !isUnapplySeq) {
328+
val arity = productArity(unapp.tpe.widen, unapp.sourcePos)
329+
unapplyProductSeqPlan(unappResult, args, arity)
330+
}
309331
else {
310332
assert(isGetMatch(unapp.tpe))
311333
val argsPlan = {
312334
val get = ref(unappResult).select(nme.get, _.info.isParameterless)
335+
val arity = productArity(get.tpe, unapp.sourcePos)
313336
if (isUnapplySeq)
314-
letAbstract(get)(unapplySeqPlan(_, args))
337+
letAbstract(get) { getResult =>
338+
if (arity > 0) unapplyProductSeqPlan(getResult, args, arity)
339+
else unapplySeqPlan(getResult, args)
340+
}
315341
else
316342
letAbstract(get) { getResult =>
317-
val selectors =
318-
if (args.tail.isEmpty) ref(getResult) :: Nil
319-
else productSelectors(get.tpe).map(ref(getResult).select(_))
320-
matchArgsPlan(selectors, args, onSuccess)
343+
if (args.tail.isEmpty) // Single pattern takes precedence
344+
matchArgsPlan(ref(getResult) :: Nil, args, onSuccess)
345+
else if (isProductMatch(get.tpe, args.length, unapp.sourcePos)) {
346+
val sels = productSelectors(get.tpe).map(ref(getResult).select(_))
347+
matchArgsPlan(sels, args, onSuccess)
348+
}
349+
else if (isProductSeqMatch(get.tpe, args.length, unapp.sourcePos))
350+
unapplyProductSeqPlan(getResult, args, arity)
351+
else { // name-based
352+
val sels = productSelectors(get.tpe).map(ref(getResult).select(_))
353+
matchArgsPlan(sels, args, onSuccess)
354+
}
321355
}
322356
}
323357
TestPlan(NonEmptyTest, unappResult, unapp.span, argsPlan)

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 57 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,22 @@ object Applications {
4747

4848
/** Does `tp` fit the "product match" conditions as an unapply result type
4949
* for a pattern with `numArgs` subpatterns?
50-
* This is the case of `tp` has members `_1` to `_N` where `N == numArgs`.
50+
* This is the case if `tp` has members `_1` to `_N` where `N == numArgs`.
5151
*/
5252
def isProductMatch(tp: Type, numArgs: Int, errorPos: SourcePosition = NoSourcePosition)(implicit ctx: Context): Boolean =
5353
numArgs > 0 && productArity(tp, errorPos) == numArgs
5454

55+
/** Does `tp` fit the "product-seq match" conditions as an unapply result type
56+
* for a pattern with `numArgs` subpatterns?
57+
* This is the case if (1) `tp` has members `_1` to `_N` where `N <= numArgs + 1`.
58+
* (2) `tp._N` conforms to Seq match
59+
*/
60+
def isProductSeqMatch(tp: Type, numArgs: Int, errorPos: SourcePosition = NoSourcePosition)(implicit ctx: Context): Boolean = {
61+
val arity = productArity(tp, errorPos)
62+
arity > 0 && arity <= numArgs + 1 &&
63+
unapplySeqTypeElemTp(productSelectorTypes(tp, errorPos).last).exists
64+
}
65+
5566
/** Does `tp` fit the "get match" conditions as an unapply result type?
5667
* This is the case of `tp` has a `get` member as well as a
5768
* parameterless `isEmpty` member of result type `Boolean`.
@@ -60,6 +71,39 @@ object Applications {
6071
extractorMemberType(tp, nme.isEmpty, errorPos).isRef(defn.BooleanClass) &&
6172
extractorMemberType(tp, nme.get, errorPos).exists
6273

74+
/** If `getType` is of the form:
75+
* ```
76+
* {
77+
* def lengthCompare(len: Int): Int // or, def length: Int
78+
* def apply(i: Int): T = a(i)
79+
* def drop(n: Int): scala.Seq[T]
80+
* def toSeq: scala.Seq[T]
81+
* }
82+
* ```
83+
* returns `T`, otherwise NoType.
84+
*/
85+
def unapplySeqTypeElemTp(getTp: Type)(implicit ctx: Context): Type = {
86+
def lengthTp = ExprType(defn.IntType)
87+
def lengthCompareTp = MethodType(List(defn.IntType), defn.IntType)
88+
def applyTp(elemTp: Type) = MethodType(List(defn.IntType), elemTp)
89+
def dropTp(elemTp: Type) = MethodType(List(defn.IntType), defn.SeqType.appliedTo(elemTp))
90+
def toSeqTp(elemTp: Type) = ExprType(defn.SeqType.appliedTo(elemTp))
91+
92+
// the result type of `def apply(i: Int): T`
93+
val elemTp = getTp.member(nme.apply).suchThat(_.info <:< applyTp(WildcardType)).info.resultType
94+
95+
def hasMethod(name: Name, tp: Type) =
96+
getTp.member(name).suchThat(getTp.memberInfo(_) <:< tp).exists
97+
98+
val isValid =
99+
elemTp.exists &&
100+
(hasMethod(nme.lengthCompare, lengthCompareTp) || hasMethod(nme.length, lengthTp)) &&
101+
hasMethod(nme.drop, dropTp(elemTp)) &&
102+
hasMethod(nme.toSeq, toSeqTp(elemTp))
103+
104+
if (isValid) elemTp else NoType
105+
}
106+
63107
def productSelectorTypes(tp: Type, errorPos: SourcePosition)(implicit ctx: Context): List[Type] = {
64108
def tupleSelectors(n: Int, tp: Type): List[Type] = {
65109
val sel = extractorMemberType(tp, nme.selectorName(n), errorPos)
@@ -89,9 +133,17 @@ object Applications {
89133
if (args.length > 1 && !(tp.derivesFrom(defn.SeqClass))) {
90134
val sels = productSelectorTypes(tp, pos)
91135
if (sels.length == args.length) sels
136+
else if (isProductSeqMatch(tp, args.length, pos)) productSeqSelectors(tp, args, pos)
92137
else tp :: Nil
93138
} else tp :: Nil
94139

140+
def productSeqSelectors(tp: Type, args: List[untpd.Tree], pos: SourcePosition)(implicit ctx: Context): List[Type] = {
141+
val selTps = productSelectorTypes(tp, pos)
142+
val arity = selTps.length
143+
val elemTp = unapplySeqTypeElemTp(selTps.last)
144+
(0 until args.length).map(i => if (i < arity - 1) selTps(i) else elemTp).toList
145+
}
146+
95147
def unapplyArgs(unapplyResult: Type, unapplyFn: Tree, args: List[untpd.Tree], pos: SourcePosition)(implicit ctx: Context): List[Type] = {
96148

97149
val unapplyName = unapplyFn.symbol.name
@@ -103,43 +155,11 @@ object Applications {
103155
Nil
104156
}
105157

106-
/** If `getType` is of the form:
107-
* ```
108-
* {
109-
* def lengthCompare(len: Int): Int // or, def length: Int
110-
* def apply(i: Int): T = a(i)
111-
* def drop(n: Int): scala.Seq[T]
112-
* def toSeq: scala.Seq[T]
113-
* }
114-
* ```
115-
* returns `T`, otherwise NoType.
116-
*/
117-
def unapplySeqTypeElemTp(getTp: Type): Type = {
118-
def lengthTp = ExprType(defn.IntType)
119-
def lengthCompareTp = MethodType(List(defn.IntType), defn.IntType)
120-
def applyTp(elemTp: Type) = MethodType(List(defn.IntType), elemTp)
121-
def dropTp(elemTp: Type) = MethodType(List(defn.IntType), defn.SeqType.appliedTo(elemTp))
122-
def toSeqTp(elemTp: Type) = defn.SeqType.appliedTo(elemTp)
123-
124-
// the result type of `def apply(i: Int): T`
125-
val elemTp = getTp.member(nme.apply).suchThat(_.info <:< applyTp(WildcardType)).info.resultType
126-
127-
def hasMethod(name: Name, tp: Type) =
128-
getTp.member(name).suchThat(getTp.memberInfo(_) <:< tp).exists
129-
130-
val isValid =
131-
elemTp.exists &&
132-
(hasMethod(nme.lengthCompare, lengthCompareTp) || hasMethod(nme.length, lengthTp)) &&
133-
hasMethod(nme.drop, dropTp(elemTp)) &&
134-
hasMethod(nme.toSeq, toSeqTp(elemTp))
135-
136-
if (isValid) elemTp else NoType
137-
}
138-
139-
if (unapplyName == nme.unapplySeq) {
158+
if (unapplyName == nme.unapplySeq) { // && ctx.scala2Mode
140159
if (isGetMatch(unapplyResult, pos)) {
141160
val elemTp = unapplySeqTypeElemTp(getTp)
142161
if (elemTp.exists) args.map(Function.const(elemTp))
162+
else if (isProductSeqMatch(getTp, args.length, pos)) productSeqSelectors(getTp, args, pos)
143163
else fail
144164
}
145165
else fail
@@ -148,6 +168,8 @@ object Applications {
148168
assert(unapplyName == nme.unapply)
149169
if (isProductMatch(unapplyResult, args.length, pos))
150170
productSelectorTypes(unapplyResult, pos)
171+
else if (isProductSeqMatch(unapplyResult, args.length, pos))
172+
productSeqSelectors(unapplyResult, args, pos)
151173
else if (isGetMatch(unapplyResult, pos))
152174
getUnapplySelectors(getTp, args, pos)
153175
else if (unapplyResult.widenSingleton isRef defn.BooleanClass)

tests/run/i3248.scala

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,27 @@
1-
object Test extends App {
1+
object Test {
22
class Foo(val name: String, val children: Int *)
33
object Foo {
4-
def unapply(f: Foo) = Some((f.name, f.children))
4+
def unapplySeq(f: Foo) = Some((f.name, f.children))
55
}
66

7-
def foo(f: Foo) = (f: Any) match {
8-
case Foo(name, ns: _*) => ns.length
9-
case List(ns: _*) => ns.length
7+
def foo(f: Foo) = f match {
8+
case Foo(name, ns : _*) =>
9+
assert(name == "hello")
10+
assert(ns(0) == 3)
11+
assert(ns(1) == 5)
1012
}
1113

12-
case class Bar(val children: Int*)
13-
14-
def bar(f: Any) = f match {
15-
case Bar(1, 2, 3) => 0
16-
case Bar(a, b) => a + b
17-
case Bar(ns: _*) => ns.length
14+
def bar(f: Foo) = f match {
15+
case Foo(name, x, y, ns : _*) =>
16+
assert(name == "hello")
17+
assert(x == 3)
18+
assert(y == 5)
19+
assert(ns.isEmpty)
1820
}
1921

20-
assert(bar(new Bar(1, 2, 3)) == 0)
21-
assert(bar(new Bar(3, 2, 1)) == 3)
22-
assert(foo(new Foo("name", 1, 2, 3)) == 3)
22+
def main(args: Array[String]): Unit = {
23+
val f = new Foo("hello", 3, 5)
24+
foo(f)
25+
bar(f)
26+
}
2327
}

0 commit comments

Comments
 (0)