@@ -18,13 +18,14 @@ import config.Printers.typr
18
18
import config .Feature
19
19
import util .{SrcPos , Stats }
20
20
import reporting .*
21
- import NameKinds .WildcardParamName
21
+ import NameKinds .{ WildcardParamName , TempResultName }
22
22
import typer .Applications .{spread , HasSpreads }
23
23
import typer .Implicits .SearchFailureType
24
24
import Constants .Constant
25
25
import cc .*
26
26
import dotty .tools .dotc .transform .MacroAnnotations .hasMacroAnnotation
27
27
import dotty .tools .dotc .core .NameKinds .DefaultGetterName
28
+ import ast .TreeInfo
28
29
29
30
object PostTyper {
30
31
val name : String = " posttyper"
@@ -379,6 +380,25 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
379
380
case _ =>
380
381
tpt
381
382
383
+ private def evalSpreadsOnce (trees : List [Tree ])(within : List [Tree ] => Tree )(using Context ): Tree =
384
+ if trees.exists:
385
+ case spread(elem) => ! (exprPurity(elem) >= TreeInfo .Idempotent )
386
+ case _ => false
387
+ then
388
+ val lifted = new mutable.ListBuffer [ValDef ]
389
+ def liftIfImpure (tree : Tree ): Tree = tree match
390
+ case tree @ Apply (fn, args) if fn.symbol == defn.spreadMethod =>
391
+ cpy.Apply (tree)(fn, args.mapConserve(liftIfImpure))
392
+ case _ if tpd.exprPurity(tree) >= TreeInfo .Idempotent =>
393
+ tree
394
+ case _ =>
395
+ val vdef = SyntheticValDef (TempResultName .fresh(), tree)
396
+ lifted += vdef
397
+ Ident (vdef.namedType)
398
+ val pureTrees = trees.mapConserve(liftIfImpure)
399
+ Block (lifted.toList, within(pureTrees))
400
+ else within(trees)
401
+
382
402
/** Translate sequence literal containing spread operators. Example:
383
403
*
384
404
* val xs, ys: List[Int]
@@ -400,50 +420,51 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
400
420
* at typer, we don't have all type variables instantiated yet.
401
421
*/
402
422
private def flattenSpreads [T ](tree : SeqLiteral )(using Context ): Tree =
403
- val SeqLiteral (elems , elemtpt) = tree
423
+ val SeqLiteral (rawElems , elemtpt) = tree
404
424
val elemType = elemtpt.tpe
405
425
val elemCls = elemType.classSymbol
406
426
407
- val lengthCalls = elems.collect:
408
- case spread(elem) => elem.select(nme.length)
409
- val singleElemCount : Tree = Literal (Constant (elems.length - lengthCalls.length))
410
- val totalLength =
411
- lengthCalls.foldLeft(singleElemCount): (acc, len) =>
412
- acc.select(defn.Int_+ ).appliedTo(len)
413
-
414
- def makeBuilder (name : String ) =
415
- ref(defn.ArraySeqBuilderModule ).select(name.toTermName)
416
- def genericBuilder = makeBuilder(" generic" )
417
- .appliedToType(elemType)
418
- .appliedTo(totalLength)
419
-
420
- val builder =
421
- if defn.ScalaValueClasses ().contains(elemCls) then
422
- makeBuilder(s " of ${elemCls.name}" ).appliedTo(totalLength)
423
- else if elemCls.derivesFrom(defn.ObjectClass ) then
424
- val classTagType = defn.ClassTagClass .typeRef.appliedTo(elemType)
425
- val classTag = atPhase(Phases .typerPhase):
426
- ctx.typer.inferImplicitArg(classTagType, tree.span.startPos)
427
- classTag.tpe match
428
- case _ : SearchFailureType =>
429
- genericBuilder
430
- case _ =>
431
- makeBuilder(" ofRef" )
432
- .appliedToType(elemType)
433
- .appliedTo(totalLength)
434
- .appliedTo(classTag)
435
- else
436
- genericBuilder
437
-
438
- elems.foldLeft(builder): (bldr, elem) =>
439
- elem match
440
- case spread(arg) =>
441
- val selector =
442
- if arg.tpe.derivesFrom(defn.SeqClass ) then " addSeq"
443
- else " addArray"
444
- bldr.select(selector.toTermName).appliedTo(arg)
445
- case _ => bldr.select(" add" .toTermName).appliedTo(elem)
446
- .select(" result" .toTermName)
427
+ evalSpreadsOnce(rawElems): elems =>
428
+ val lengthCalls = elems.collect:
429
+ case spread(elem) => elem.select(nme.length)
430
+ val singleElemCount : Tree = Literal (Constant (elems.length - lengthCalls.length))
431
+ val totalLength =
432
+ lengthCalls.foldLeft(singleElemCount): (acc, len) =>
433
+ acc.select(defn.Int_+ ).appliedTo(len)
434
+
435
+ def makeBuilder (name : String ) =
436
+ ref(defn.ArraySeqBuilderModule ).select(name.toTermName)
437
+ def genericBuilder = makeBuilder(" generic" )
438
+ .appliedToType(elemType)
439
+ .appliedTo(totalLength)
440
+
441
+ val builder =
442
+ if defn.ScalaValueClasses ().contains(elemCls) then
443
+ makeBuilder(s " of ${elemCls.name}" ).appliedTo(totalLength)
444
+ else if elemCls.derivesFrom(defn.ObjectClass ) then
445
+ val classTagType = defn.ClassTagClass .typeRef.appliedTo(elemType)
446
+ val classTag = atPhase(Phases .typerPhase):
447
+ ctx.typer.inferImplicitArg(classTagType, tree.span.startPos)
448
+ classTag.tpe match
449
+ case _ : SearchFailureType =>
450
+ genericBuilder
451
+ case _ =>
452
+ makeBuilder(" ofRef" )
453
+ .appliedToType(elemType)
454
+ .appliedTo(totalLength)
455
+ .appliedTo(classTag)
456
+ else
457
+ genericBuilder
458
+
459
+ elems.foldLeft(builder): (bldr, elem) =>
460
+ elem match
461
+ case spread(arg) =>
462
+ val selector =
463
+ if arg.tpe.derivesFrom(defn.SeqClass ) then " addSeq"
464
+ else " addArray"
465
+ bldr.select(selector.toTermName).appliedTo(arg)
466
+ case _ => bldr.select(" add" .toTermName).appliedTo(elem)
467
+ .select(" result" .toTermName)
447
468
end flattenSpreads
448
469
449
470
override def transform (tree : Tree )(using Context ): Tree =
0 commit comments