Skip to content

Commit 1f333bc

Browse files
committed
Make sure spreads are evaluated only once
We need to access them twice because we first need to take their length, then append them to the buffer. If a spread might have side effects, lift all side-effecting arguments out in the order of occurrence.
1 parent 0bc9f43 commit 1f333bc

File tree

3 files changed

+86
-42
lines changed

3 files changed

+86
-42
lines changed

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 63 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@ import config.Printers.typr
1818
import config.Feature
1919
import util.{SrcPos, Stats}
2020
import reporting.*
21-
import NameKinds.WildcardParamName
21+
import NameKinds.{WildcardParamName, TempResultName}
2222
import typer.Applications.{spread, HasSpreads}
2323
import typer.Implicits.SearchFailureType
2424
import Constants.Constant
2525
import cc.*
2626
import dotty.tools.dotc.transform.MacroAnnotations.hasMacroAnnotation
2727
import dotty.tools.dotc.core.NameKinds.DefaultGetterName
28+
import ast.TreeInfo
2829

2930
object PostTyper {
3031
val name: String = "posttyper"
@@ -379,6 +380,25 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
379380
case _ =>
380381
tpt
381382

383+
private def evalSpreadsOnce(trees: List[Tree])(within: List[Tree] => Tree)(using Context): Tree =
384+
if trees.exists:
385+
case spread(elem) => !(exprPurity(elem) >= TreeInfo.Idempotent)
386+
case _ => false
387+
then
388+
val lifted = new mutable.ListBuffer[ValDef]
389+
def liftIfImpure(tree: Tree): Tree = tree match
390+
case tree @ Apply(fn, args) if fn.symbol == defn.spreadMethod =>
391+
cpy.Apply(tree)(fn, args.mapConserve(liftIfImpure))
392+
case _ if tpd.exprPurity(tree) >= TreeInfo.Idempotent =>
393+
tree
394+
case _ =>
395+
val vdef = SyntheticValDef(TempResultName.fresh(), tree)
396+
lifted += vdef
397+
Ident(vdef.namedType)
398+
val pureTrees = trees.mapConserve(liftIfImpure)
399+
Block(lifted.toList, within(pureTrees))
400+
else within(trees)
401+
382402
/** Translate sequence literal containing spread operators. Example:
383403
*
384404
* val xs, ys: List[Int]
@@ -400,50 +420,51 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
400420
* at typer, we don't have all type variables instantiated yet.
401421
*/
402422
private def flattenSpreads[T](tree: SeqLiteral)(using Context): Tree =
403-
val SeqLiteral(elems, elemtpt) = tree
423+
val SeqLiteral(rawElems, elemtpt) = tree
404424
val elemType = elemtpt.tpe
405425
val elemCls = elemType.classSymbol
406426

407-
val lengthCalls = elems.collect:
408-
case spread(elem) => elem.select(nme.length)
409-
val singleElemCount: Tree = Literal(Constant(elems.length - lengthCalls.length))
410-
val totalLength =
411-
lengthCalls.foldLeft(singleElemCount): (acc, len) =>
412-
acc.select(defn.Int_+).appliedTo(len)
413-
414-
def makeBuilder(name: String) =
415-
ref(defn.ArraySeqBuilderModule).select(name.toTermName)
416-
def genericBuilder = makeBuilder("generic")
417-
.appliedToType(elemType)
418-
.appliedTo(totalLength)
419-
420-
val builder =
421-
if defn.ScalaValueClasses().contains(elemCls) then
422-
makeBuilder(s"of${elemCls.name}").appliedTo(totalLength)
423-
else if elemCls.derivesFrom(defn.ObjectClass) then
424-
val classTagType = defn.ClassTagClass.typeRef.appliedTo(elemType)
425-
val classTag = atPhase(Phases.typerPhase):
426-
ctx.typer.inferImplicitArg(classTagType, tree.span.startPos)
427-
classTag.tpe match
428-
case _: SearchFailureType =>
429-
genericBuilder
430-
case _ =>
431-
makeBuilder("ofRef")
432-
.appliedToType(elemType)
433-
.appliedTo(totalLength)
434-
.appliedTo(classTag)
435-
else
436-
genericBuilder
437-
438-
elems.foldLeft(builder): (bldr, elem) =>
439-
elem match
440-
case spread(arg) =>
441-
val selector =
442-
if arg.tpe.derivesFrom(defn.SeqClass) then "addSeq"
443-
else "addArray"
444-
bldr.select(selector.toTermName).appliedTo(arg)
445-
case _ => bldr.select("add".toTermName).appliedTo(elem)
446-
.select("result".toTermName)
427+
evalSpreadsOnce(rawElems): elems =>
428+
val lengthCalls = elems.collect:
429+
case spread(elem) => elem.select(nme.length)
430+
val singleElemCount: Tree = Literal(Constant(elems.length - lengthCalls.length))
431+
val totalLength =
432+
lengthCalls.foldLeft(singleElemCount): (acc, len) =>
433+
acc.select(defn.Int_+).appliedTo(len)
434+
435+
def makeBuilder(name: String) =
436+
ref(defn.ArraySeqBuilderModule).select(name.toTermName)
437+
def genericBuilder = makeBuilder("generic")
438+
.appliedToType(elemType)
439+
.appliedTo(totalLength)
440+
441+
val builder =
442+
if defn.ScalaValueClasses().contains(elemCls) then
443+
makeBuilder(s"of${elemCls.name}").appliedTo(totalLength)
444+
else if elemCls.derivesFrom(defn.ObjectClass) then
445+
val classTagType = defn.ClassTagClass.typeRef.appliedTo(elemType)
446+
val classTag = atPhase(Phases.typerPhase):
447+
ctx.typer.inferImplicitArg(classTagType, tree.span.startPos)
448+
classTag.tpe match
449+
case _: SearchFailureType =>
450+
genericBuilder
451+
case _ =>
452+
makeBuilder("ofRef")
453+
.appliedToType(elemType)
454+
.appliedTo(totalLength)
455+
.appliedTo(classTag)
456+
else
457+
genericBuilder
458+
459+
elems.foldLeft(builder): (bldr, elem) =>
460+
elem match
461+
case spread(arg) =>
462+
val selector =
463+
if arg.tpe.derivesFrom(defn.SeqClass) then "addSeq"
464+
else "addArray"
465+
bldr.select(selector.toTermName).appliedTo(arg)
466+
case _ => bldr.select("add".toTermName).appliedTo(elem)
467+
.select("result".toTermName)
447468
end flattenSpreads
448469

449470
override def transform(tree: Tree)(using Context): Tree =

tests/run/spreads.check

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
ArraySeq(1, 2, 3)
2+
ArraySeq(1, 2, 3)
3+
ArraySeq(1, 2, 1, 2, 3)
4+
ArraySeq(1, 2, 1, 2, 3)
5+
ArraySeq(1, 1, 2, 3, 2)
6+
ArraySeq(1, 1, 2, 3, 2, 1, 2, 3, 3)
7+
ArraySeq(1, 1, 2, 3, true, A, false)
8+
ArraySeq(1, 1, 2, 3, 2)
9+
one
10+
one-two-three
11+
two
12+
ArraySeq(1, 1, 2, 3, 2)

tests/run/spreads.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,20 @@ def useInt(xs: Int*) = ???
1313

1414
val xs = List(1, 2, 3)
1515
val ys = List("A")
16+
val ao = Option(1.0).toList
1617

1718
val x: Unit = use[Int](1, 2, xs*)
1819
val y = use(1, 2, xs*)
1920
use(1, xs*, 2)
2021
use(1, xs*, 2, xs*, 3)
2122
use(1, xs*, true, ys*, false)
23+
use(1, identity(xs)*, 2)
24+
25+
def one = { println("one"); 1 }
26+
def two = { println("two"); 2 }
27+
def oneTwoThree = { println("one-two-three"); xs }
28+
use(one, oneTwoThree*, two)
29+
//use(1.0, ao*, 2.0)
30+
31+
32+

0 commit comments

Comments
 (0)