Skip to content

Commit 59ec380

Browse files
committed
Allow multiple spreads in function arguments
1 parent ca400bd commit 59ec380

File tree

16 files changed

+413
-19
lines changed

16 files changed

+413
-19
lines changed

compiler/src/dotty/tools/dotc/config/Feature.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ object Feature:
3737
val modularity = experimental("modularity")
3838
val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions")
3939
val packageObjectValues = experimental("packageObjectValues")
40+
val multiSpreads = experimental("multiSpreads")
4041
val subCases = experimental("subCases")
4142

4243
def experimentalAutoEnableFeatures(using Context): List[TermName] =

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,11 @@ class Definitions {
468468
@tu lazy val throwMethod: TermSymbol = enterMethod(OpsPackageClass, nme.THROWkw,
469469
MethodType(List(ThrowableType), NothingType))
470470

471+
@tu lazy val spreadMethod = enterMethod(OpsPackageClass, nme.spread,
472+
PolyType(TypeBounds.empty :: Nil)(
473+
tl => MethodType(AnyType :: Nil, tl.paramRefs(0))
474+
))
475+
471476
@tu lazy val NothingClass: ClassSymbol = enterCompleteClassSymbol(
472477
ScalaPackageClass, tpnme.Nothing, AbstractFinal, List(AnyType))
473478
def NothingType: TypeRef = NothingClass.typeRef
@@ -519,6 +524,8 @@ class Definitions {
519524
@tu lazy val newGenericArrayMethod: TermSymbol = DottyArraysModule.requiredMethod("newGenericArray")
520525
@tu lazy val newArrayMethod: TermSymbol = DottyArraysModule.requiredMethod("newArray")
521526

527+
@tu lazy val ArraySeqBuilderModule: Symbol = requiredModule("scala.runtime.ArraySeqBuilder")
528+
522529
def getWrapVarargsArrayModule: Symbol = ScalaRuntimeModule
523530

524531
// The set of all wrap{X, Ref}Array methods, where X is a value type
@@ -2264,7 +2271,7 @@ class Definitions {
22642271

22652272
/** Lists core methods that don't have underlying bytecode, but are synthesized on-the-fly in every reflection universe */
22662273
@tu lazy val syntheticCoreMethods: List[TermSymbol] =
2267-
AnyMethods ++ ObjectMethods ++ List(String_+, throwMethod)
2274+
AnyMethods ++ ObjectMethods ++ List(String_+, throwMethod, spreadMethod)
22682275

22692276
@tu lazy val reservedScalaClassNames: Set[Name] = syntheticScalaClasses.map(_.name).toSet
22702277

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,7 @@ object StdNames {
619619
val setSymbol: N = "setSymbol"
620620
val setType: N = "setType"
621621
val setTypeSignature: N = "setTypeSignature"
622+
val spread: N = "spread"
622623
val standardInterpolator: N = "standardInterpolator"
623624
val staticClass : N = "staticClass"
624625
val staticModule : N = "staticModule"

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,17 +1056,22 @@ object Parsers {
10561056
}
10571057

10581058
/** Is current ident a `*`, and is it followed by a `)`, `, )`, `,EOF`? The latter two are not
1059-
syntactically valid, but we need to include them here for error recovery. */
1059+
syntactically valid, but we need to include them here for error recovery.
1060+
Under experimental.multiSpreads we allow `*`` followed by `,` unconditionally.
1061+
*/
10601062
def followingIsVararg(): Boolean =
10611063
in.isIdent(nme.raw.STAR) && {
10621064
val lookahead = in.LookaheadScanner()
10631065
lookahead.nextToken()
10641066
lookahead.token == RPAREN
10651067
|| lookahead.token == COMMA
1066-
&& {
1067-
lookahead.nextToken()
1068-
lookahead.token == RPAREN || lookahead.token == EOF
1069-
}
1068+
&& (
1069+
in.featureEnabled(Feature.multiSpreads)
1070+
|| {
1071+
lookahead.nextToken()
1072+
lookahead.token == RPAREN || lookahead.token == EOF
1073+
}
1074+
)
10701075
}
10711076

10721077
/** When encountering a `:`, is that in the binding of a lambda?

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ import config.Feature
1919
import util.{SrcPos, Stats}
2020
import reporting.*
2121
import NameKinds.WildcardParamName
22+
import typer.Applications.{spread, HasSpreads}
23+
import typer.Implicits.SearchFailureType
24+
import Constants.Constant
2225
import cc.*
2326
import dotty.tools.dotc.transform.MacroAnnotations.hasMacroAnnotation
2427
import dotty.tools.dotc.core.NameKinds.DefaultGetterName
@@ -376,6 +379,73 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
376379
case _ =>
377380
tpt
378381

382+
/** Translate sequence literal containing spread operators. Example:
383+
*
384+
* val xs, ys: List[Int]
385+
* [1, xs*, 2, ys*]
386+
*
387+
* Here the sequence literal is translated at typer tp
388+
*
389+
* [1, spread(xs), 2, spread(ys)]
390+
*
391+
* This then translates to
392+
*
393+
* scala.runtime.ArraySeqBuilcder.ofInt(2 + xs.length + ys.length)
394+
* .add(1)
395+
* .addSeq(xs)
396+
* .add(2)
397+
* .addSeq(ys)
398+
*
399+
* The reason for doing a two-step typer/postTyper translation is that
400+
* at typer, we don't have all type variables instantiated yet.
401+
*/
402+
private def flattenSpreads[T](tree: SeqLiteral)(using Context): Tree =
403+
val SeqLiteral(elems, elemtpt) = tree
404+
val elemType = elemtpt.tpe
405+
val elemCls = elemType.classSymbol
406+
407+
val lengthCalls = elems.collect:
408+
case spread(elem) => elem.select(nme.length)
409+
val singleElemCount: Tree = Literal(Constant(elems.length - lengthCalls.length))
410+
val totalLength =
411+
lengthCalls.foldLeft(singleElemCount): (acc, len) =>
412+
acc.select(defn.Int_+).appliedTo(len)
413+
414+
def makeBuilder(name: String) =
415+
ref(defn.ArraySeqBuilderModule).select(name.toTermName)
416+
def genericBuilder = makeBuilder("generic")
417+
.appliedToType(elemType)
418+
.appliedTo(totalLength)
419+
420+
val builder =
421+
if defn.ScalaValueClasses().contains(elemCls) then
422+
makeBuilder(s"of${elemCls.name}").appliedTo(totalLength)
423+
else if elemCls.derivesFrom(defn.ObjectClass) then
424+
val classTagType = defn.ClassTagClass.typeRef.appliedTo(elemType)
425+
val classTag = atPhase(Phases.typerPhase):
426+
ctx.typer.inferImplicitArg(classTagType, tree.span.startPos)
427+
classTag.tpe match
428+
case _: SearchFailureType =>
429+
genericBuilder
430+
case _ =>
431+
makeBuilder("ofRef")
432+
.appliedToType(elemType)
433+
.appliedTo(totalLength)
434+
.appliedTo(classTag)
435+
else
436+
genericBuilder
437+
438+
elems.foldLeft(builder): (bldr, elem) =>
439+
elem match
440+
case spread(arg) =>
441+
val selector =
442+
if arg.tpe.derivesFrom(defn.SeqClass) then "addSeq"
443+
else "addArray"
444+
bldr.select(selector.toTermName).appliedTo(arg)
445+
case _ => bldr.select("add".toTermName).appliedTo(elem)
446+
.select("result".toTermName)
447+
end flattenSpreads
448+
379449
override def transform(tree: Tree)(using Context): Tree =
380450
try tree match {
381451
// TODO move CaseDef case lower: keep most probable trees first for performance
@@ -592,6 +662,8 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
592662
case tree: RefinedTypeTree =>
593663
Checking.checkPolyFunctionType(tree)
594664
super.transform(tree)
665+
case tree: SeqLiteral if tree.hasAttachment(HasSpreads) =>
666+
flattenSpreads(tree)
595667
case _: Quote | _: QuotePattern =>
596668
ctx.compilationUnit.needsStaging = true
597669
super.transform(tree)

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import Inferencing.*
2424
import reporting.*
2525
import Nullables.*, NullOpsDecorator.*
2626
import config.{Feature, MigrationVersion, SourceVersion}
27+
import util.Property
2728

2829
import collection.mutable
2930
import config.Printers.{overload, typr, unapp}
@@ -42,6 +43,17 @@ import dotty.tools.dotc.inlines.Inlines
4243
object Applications {
4344
import tpd.*
4445

46+
/** Attachment key for SeqLiterals containing spreads. Eliminated at PostTyper */
47+
val HasSpreads = new Property.StickyKey[Unit]
48+
49+
/** An extractor for spreads in sequence literals */
50+
object spread:
51+
def apply(arg: Tree, elemtpt: Tree)(using Context) =
52+
ref(defn.spreadMethod).appliedToTypeTree(elemtpt).appliedTo(arg)
53+
def unapply(arg: Apply)(using Context): Option[Tree] = arg match
54+
case Apply(fn, x :: Nil) if fn.symbol == defn.spreadMethod => Some(x)
55+
case _ => None
56+
4557
def extractorMember(tp: Type, name: Name)(using Context): SingleDenotation =
4658
tp.member(name).suchThat(sym => sym.info.isParameterless && sym.info.widenExpr.isValueType)
4759

@@ -797,14 +809,19 @@ trait Applications extends Compatibility {
797809
addTyped(arg)
798810
case _ =>
799811
val elemFormal = formal.widenExpr.argTypesLo.head
800-
val typedArgs =
801-
harmonic(harmonizeArgs, elemFormal) {
802-
args.map { arg =>
812+
if Feature.enabled(Feature.multiSpreads)
813+
&& !ctx.isAfterTyper && args.exists(isVarArg)
814+
then
815+
args.foreach: arg =>
816+
if isVarArg(arg)
817+
then addArg(typedArg(arg, formal), formal)
818+
else addArg(typedArg(arg, elemFormal), elemFormal)
819+
else
820+
val typedArgs = harmonic(harmonizeArgs, elemFormal):
821+
args.map: arg =>
803822
checkNoVarArg(arg)
804823
typedArg(arg, elemFormal)
805-
}
806-
}
807-
typedArgs.foreach(addArg(_, elemFormal))
824+
typedArgs.foreach(addArg(_, elemFormal))
808825
makeVarArg(args.length, elemFormal)
809826
}
810827
else args match {
@@ -944,12 +961,18 @@ trait Applications extends Compatibility {
944961
typedArgBuf += typedArg
945962
ok = ok & !typedArg.tpe.isError
946963

947-
def makeVarArg(n: Int, elemFormal: Type): Unit = {
948-
val args = typedArgBuf.takeRight(n).toList
964+
def makeVarArg(n: Int, elemFormal: Type): Unit =
965+
var args = typedArgBuf.takeRight(n).toList
949966
typedArgBuf.dropRightInPlace(n)
950-
val elemtpt = TypeTree(elemFormal.normalizedTupleType, inferred = true)
951-
typedArgBuf += seqToRepeated(SeqLiteral(args, elemtpt))
952-
}
967+
val elemTpe = elemFormal.normalizedTupleType
968+
val elemtpt = TypeTree(elemTpe, inferred = true)
969+
def wrapSpread(arg: Tree): Tree = arg match
970+
case Typed(argExpr, tpt) if tpt.tpe.isRepeatedParam => spread(argExpr, elemtpt)
971+
case _ => arg
972+
val args1 = args.mapConserve(wrapSpread)
973+
val seqLit = SeqLiteral(args1, elemtpt)
974+
if args1 ne args then seqLit.putAttachment(HasSpreads, ())
975+
typedArgBuf += seqToRepeated(seqLit)
953976

954977
def harmonizeArgs(args: List[TypedArg]): List[Tree] =
955978
// harmonize args only if resType depends on parameter types

library/src/scala/compiletime/Spread.scala

Whitespace-only changes.

library/src/scala/language.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,11 +350,15 @@ object language {
350350
@compileTimeOnly("`packageObjectValues` can only be used at compile time in import statements")
351351
object packageObjectValues
352352

353+
/** Experimental support for multiple spread arguments.
354+
*/
355+
@compileTimeOnly("`multiSpreads` can only be used at compile time in import statements")
356+
object multiSpreads
357+
353358
/** Experimental support for match expressions with sub cases.
354359
*/
355360
@compileTimeOnly("`subCases` can only be used at compile time in import statements")
356361
object subCases
357-
358362
}
359363

360364
/** The deprecated object contains features that are no longer officially suypported in Scala.

0 commit comments

Comments
 (0)