Skip to content

Commit ba00c03

Browse files
committed
Implement spreads in the middle of pattern sequences
1 parent 20ae6c1 commit ba00c03

File tree

9 files changed

+120
-26
lines changed

9 files changed

+120
-26
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,11 +570,12 @@ class Definitions {
570570
@tu lazy val Seq_apply : Symbol = SeqClass.requiredMethod(nme.apply)
571571
@tu lazy val Seq_head : Symbol = SeqClass.requiredMethod(nme.head)
572572
@tu lazy val Seq_drop : Symbol = SeqClass.requiredMethod(nme.drop)
573+
@tu lazy val Seq_dropRight : Symbol = SeqClass.requiredMethod(nme.dropRight)
574+
@tu lazy val Seq_takeRight : Symbol = SeqClass.requiredMethod(nme.takeRight)
573575
@tu lazy val Seq_lengthCompare: Symbol = SeqClass.requiredMethod(nme.lengthCompare, List(IntType))
574576
@tu lazy val Seq_length : Symbol = SeqClass.requiredMethod(nme.length)
575577
@tu lazy val Seq_toSeq : Symbol = SeqClass.requiredMethod(nme.toSeq)
576578

577-
578579
@tu lazy val StringOps: Symbol = requiredClass("scala.collection.StringOps")
579580
@tu lazy val StringOps_format: Symbol = StringOps.requiredMethod(nme.format)
580581

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,7 @@ object StdNames {
470470
val doubleHash: N = "doubleHash"
471471
val dotty: N = "dotty"
472472
val drop: N = "drop"
473+
val dropRight: N = "dropRight"
473474
val dynamics: N = "dynamics"
474475
val elem: N = "elem"
475476
val elems: N = "elems"
@@ -802,6 +803,7 @@ object StdNames {
802803
val takeModulo: N = "takeModulo"
803804
val takeNot: N = "takeNot"
804805
val takeOr: N = "takeOr"
806+
val takeRight: N = "takeRight"
805807
val takeXor: N = "takeXor"
806808
val testEqual: N = "testEqual"
807809
val testGreaterOrEqualThan: N = "testGreaterOrEqualThan"

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3352,7 +3352,9 @@ object Parsers {
33523352
if (in.token == RPAREN) Nil else patterns(location)
33533353

33543354
/** ArgumentPatterns ::= ‘(’ [Patterns] ‘)’
3355-
* | ‘(’ [Patterns ‘,’] PatVar ‘*’ ‘)’
3355+
* | ‘(’ [Patterns ‘,’] PatVar ‘*’ [‘,’ Patterns] ‘)’
3356+
*
3357+
* -- It is checked in Typer that there are no repeated PatVar arguments.
33563358
*/
33573359
def argumentPatterns(): List[Tree] =
33583360
inParensWithCommas(patternsOpt(Location.InPatternArgs))

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -299,30 +299,57 @@ object PatternMatcher {
299299
}
300300

301301
/** Plan for matching the sequence in `getResult` against sequence elements
302-
* and a possible last varargs argument `args`.
302+
* `args`. Sequence elements may contain a varargs argument.
303+
* Example:
304+
*
305+
* lst match case Seq(1, xs*, 2, 3) => ...
306+
*
307+
* generates code which is equivalent to:
308+
*
309+
* if lst != null then
310+
* if lst.lengthCompare >= 1 then
311+
* if lst(0) == 1 then
312+
* val x1 = lst.drop(1)
313+
* val xs = x1.dropRight(2)
314+
* val x2 = lst.takeRight(2)
315+
* if x2.lengthCompare >= 2 then
316+
* if x2(0) == 2 then
317+
* if x2(1) == 3 then
318+
* return[matchResult] ...
303319
*/
304-
def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan = args.lastOption match {
305-
case Some(VarArgPattern(arg)) =>
306-
val matchRemaining =
307-
if (args.length == 1) {
308-
val toSeq = ref(getResult)
309-
.select(defn.Seq_toSeq.matchingMember(getResult.info))
310-
letAbstract(toSeq) { toSeqResult =>
311-
patternPlan(toSeqResult, arg, onSuccess)
312-
}
313-
}
314-
else {
315-
val dropped = ref(getResult)
316-
.select(defn.Seq_drop.matchingMember(getResult.info))
317-
.appliedTo(Literal(Constant(args.length - 1)))
318-
letAbstract(dropped) { droppedResult =>
319-
patternPlan(droppedResult, arg, onSuccess)
320-
}
321-
}
322-
matchElemsPlan(getResult, args.init, exact = false, matchRemaining)
323-
case _ =>
324-
matchElemsPlan(getResult, args, exact = true, onSuccess)
325-
}
320+
def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan =
321+
val (leading, varargAndRest) = args.span:
322+
case VarArgPattern(_) => false
323+
case _ => true
324+
varargAndRest match
325+
case VarArgPattern(arg) :: trailing =>
326+
val remaining =
327+
if leading.isEmpty then
328+
ref(getResult)
329+
.select(defn.Seq_toSeq.matchingMember(getResult.info))
330+
else
331+
ref(getResult)
332+
.select(defn.Seq_drop.matchingMember(getResult.info))
333+
.appliedTo(Literal(Constant(leading.length)))
334+
val matchRemaining =
335+
letAbstract(remaining): remainingResult =>
336+
if trailing.isEmpty then
337+
patternPlan(remainingResult, arg, onSuccess)
338+
else
339+
val seq = ref(remainingResult)
340+
.select(defn.Seq_dropRight.matchingMember(remainingResult.info))
341+
.appliedTo(Literal(Constant(trailing.length)))
342+
letAbstract(seq): seqResult =>
343+
val rest = ref(remainingResult)
344+
.select(defn.Seq_takeRight.matchingMember(remainingResult.info))
345+
.appliedTo(Literal(Constant(trailing.length)))
346+
val matchTrailing =
347+
letAbstract(rest): trailingResult =>
348+
matchElemsPlan(trailingResult, trailing, exact = true, onSuccess)
349+
patternPlan(seqResult, arg, matchTrailing)
350+
matchElemsPlan(getResult, leading, exact = false, matchRemaining)
351+
case _ =>
352+
matchElemsPlan(getResult, args, exact = true, onSuccess)
326353

327354
/** Plan for matching the sequence in `getResult`
328355
*

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,11 @@ object Applications {
303303
report.error(UnapplyInvalidNumberOfArguments(qual, argTypes), pos)
304304
argTypes.take(args.length) ++
305305
List.fill(argTypes.length - args.length)(WildcardType)
306+
307+
val varArgs = alignedArgs.filter(untpd.isWildcardStarArg)
308+
if varArgs.length >= 2 then
309+
report.error(em"Ony one spread pattern allowed in sequence", varArgs(1).srcPos)
310+
306311
alignedArgs.lazyZip(alignedArgTypes).map(typer.typed(_, _))
307312
.showing(i"unapply patterns = $result", unapp)
308313

docs/_docs/internals/syntax.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ Patterns ::= Pattern {‘,’ Pattern}
365365
NamedPattern ::= id '=' Pattern
366366
367367
ArgumentPatterns ::= ‘(’ [Patterns] ‘)’ Apply(fn, pats)
368-
| ‘(’ [Patterns ‘,’] PatVar ‘*’ ‘)’
368+
| ‘(’ [Patterns ‘,’] PatVar ‘*’ [‘,’ Patterns]‘)’
369369
```
370370

371371
### Type and Value Parameters

tests/neg/spread-patterns.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import language.experimental.multiSpreads
2+
3+
def use[T](xs: T*) = println(xs)
4+
5+
def useInt(xs: Int*) = ???
6+
7+
@main def Test() =
8+
val arr: Array[Int] = Array(1, 2, 3, 4, 5, 6)
9+
val xs = List(1, 2, 3, 4, 5, 6)
10+
11+
xs match
12+
case List(1, 2, xs*, ys*, 6) => println(xs) // error
13+
14+
15+

tests/run/spread-patterns.check

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
List(3, 4, 5)
2+
ArraySeq(4, 5, 6)
3+
ArraySeq(1, 2, 3)
4+
ArraySeq(3, 4, 5, 6)

tests/run/spread-patterns.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import language.experimental.multiSpreads
2+
3+
def use[T](xs: T*) = println(xs)
4+
5+
def useInt(xs: Int*) = ???
6+
7+
@main def Test() =
8+
val arr: Array[Int] = Array(1, 2, 3, 4, 5, 6)
9+
val lst = List(1, 2, 3, 4, 5, 6)
10+
11+
lst match
12+
case List(1, xs*, 2, 3, 4, 5, 6) =>
13+
assert(xs.isEmpty)
14+
15+
lst match
16+
case List(1, 2, xs*, 6) => println(xs)
17+
18+
arr match
19+
case Array(1, 2, xs*, 7) => assert(false)
20+
case Array(1, 2, 3, xs*) => println(xs)
21+
22+
arr match
23+
case Array(xs*, 1, 2) => assert(false)
24+
case Array(xs*, 4, 5, 6) => println(xs)
25+
26+
arr match
27+
case Array(1, 2, xs*) => println(xs)
28+
29+
lst match
30+
case List(1, 2, 3, 4, 5, 6, xs*) => assert(xs.isEmpty)
31+
32+
lst match
33+
case Seq(xs*, 1, 2, 3, 4, 5, 6) => assert(xs.isEmpty)
34+
35+
36+
37+
38+

0 commit comments

Comments
 (0)