Skip to content

Commit 54cea7b

Browse files
committed
Optimize length testing for sequence matches
1 parent ba00c03 commit 54cea7b

File tree

2 files changed

+26
-15
lines changed

2 files changed

+26
-15
lines changed

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ object PatternMatcher {
198198
case object NonNullTest extends Test // scrutinee ne null
199199
case object GuardTest extends Test // scrutinee
200200

201+
val noLengthTest = LengthTest(0, exact = false)
202+
201203
// ------- Generating plans from trees ------------------------
202204

203205
/** A set of variabes that are known to be not null */
@@ -291,12 +293,14 @@ object PatternMatcher {
291293
/** Plan for matching the sequence in `seqSym` against sequence elements `args`.
292294
* If `exact` is true, the sequence is not permitted to have any elements following `args`.
293295
*/
294-
def matchElemsPlan(seqSym: Symbol, args: List[Tree], exact: Boolean, onSuccess: Plan) = {
295-
val selectors = args.indices.toList.map(idx =>
296-
ref(seqSym).select(defn.Seq_apply.matchingMember(seqSym.info)).appliedTo(Literal(Constant(idx))))
297-
TestPlan(LengthTest(args.length, exact), seqSym, seqSym.span,
298-
matchArgsPlan(selectors, args, onSuccess))
299-
}
296+
def matchElemsPlan(seqSym: Symbol, args: List[Tree], lengthTest: LengthTest, onSuccess: Plan) =
297+
val selectors = args.indices.toList.map: idx =>
298+
ref(seqSym).select(defn.Seq_apply.matchingMember(seqSym.info)).appliedTo(Literal(Constant(idx)))
299+
if lengthTest.len == 0 && lengthTest.exact == false then // redundant test
300+
matchArgsPlan(selectors, args, onSuccess)
301+
else
302+
TestPlan(lengthTest, seqSym, seqSym.span,
303+
matchArgsPlan(selectors, args, onSuccess))
300304

301305
/** Plan for matching the sequence in `getResult` against sequence elements
302306
* `args`. Sequence elements may contain a varargs argument.
@@ -307,15 +311,13 @@ object PatternMatcher {
307311
* generates code which is equivalent to:
308312
*
309313
* if lst != null then
310-
* if lst.lengthCompare >= 1 then
314+
* if lst.lengthCompare >= 3 then
311315
* if lst(0) == 1 then
312316
* val x1 = lst.drop(1)
313317
* val xs = x1.dropRight(2)
314318
* val x2 = lst.takeRight(2)
315-
* if x2.lengthCompare >= 2 then
316-
* if x2(0) == 2 then
317-
* if x2(1) == 3 then
318-
* return[matchResult] ...
319+
* if x2(0) == 2 && x2(1) == 3 then
320+
* return[matchResult] ...
319321
*/
320322
def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan =
321323
val (leading, varargAndRest) = args.span:
@@ -345,11 +347,13 @@ object PatternMatcher {
345347
.appliedTo(Literal(Constant(trailing.length)))
346348
val matchTrailing =
347349
letAbstract(rest): trailingResult =>
348-
matchElemsPlan(trailingResult, trailing, exact = true, onSuccess)
350+
matchElemsPlan(trailingResult, trailing, noLengthTest, onSuccess)
349351
patternPlan(seqResult, arg, matchTrailing)
350-
matchElemsPlan(getResult, leading, exact = false, matchRemaining)
352+
matchElemsPlan(getResult, leading,
353+
LengthTest(leading.length + trailing.length, exact = false),
354+
matchRemaining)
351355
case _ =>
352-
matchElemsPlan(getResult, args, exact = true, onSuccess)
356+
matchElemsPlan(getResult, args, LengthTest(args.length, exact = true), onSuccess)
353357

354358
/** Plan for matching the sequence in `getResult`
355359
*
@@ -518,7 +522,7 @@ object PatternMatcher {
518522
case WildcardPattern() | This(_) =>
519523
onSuccess
520524
case SeqLiteral(pats, _) =>
521-
matchElemsPlan(scrutinee, pats, exact = true, onSuccess)
525+
matchElemsPlan(scrutinee, pats, LengthTest(pats.length, exact = true), onSuccess)
522526
case _ =>
523527
TestPlan(EqualTest(tree), scrutinee, tree.span, onSuccess)
524528
}

tests/run/spreads.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,12 @@ def useInt(xs: Int*) = ???
2828
use(one, oneTwoThree*, two)
2929
//use(1.0, ao*, 2.0)
3030

31+
val numbers1 = Array(1, 2, 3)
32+
val numbers2 = List(4, 5, 6)
33+
34+
def sum(xs: Int*) = xs.sum
35+
36+
assert(sum(0, numbers1*, numbers2*, 4) == 25)
37+
3138

3239

0 commit comments

Comments
 (0)