Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/config/Feature.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ object Feature:
val modularity = experimental("modularity")
val quotedPatternsWithPolymorphicFunctions = experimental("quotedPatternsWithPolymorphicFunctions")
val packageObjectValues = experimental("packageObjectValues")
val multiSpreads = experimental("multiSpreads")
val subCases = experimental("subCases")

def experimentalAutoEnableFeatures(using Context): List[TermName] =
Expand Down
9 changes: 8 additions & 1 deletion compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,11 @@ class Definitions {
@tu lazy val throwMethod: TermSymbol = enterMethod(OpsPackageClass, nme.THROWkw,
MethodType(List(ThrowableType), NothingType))

@tu lazy val spreadMethod = enterMethod(OpsPackageClass, nme.spread,
PolyType(TypeBounds.empty :: Nil)(
tl => MethodType(AnyType :: Nil, tl.paramRefs(0))
))

@tu lazy val NothingClass: ClassSymbol = enterCompleteClassSymbol(
ScalaPackageClass, tpnme.Nothing, AbstractFinal, List(AnyType))
def NothingType: TypeRef = NothingClass.typeRef
Expand Down Expand Up @@ -519,6 +524,8 @@ class Definitions {
@tu lazy val newGenericArrayMethod: TermSymbol = DottyArraysModule.requiredMethod("newGenericArray")
@tu lazy val newArrayMethod: TermSymbol = DottyArraysModule.requiredMethod("newArray")

@tu lazy val ArraySeqBuilderModule: Symbol = requiredModule("scala.runtime.ArraySeqBuilder")

def getWrapVarargsArrayModule: Symbol = ScalaRuntimeModule

// The set of all wrap{X, Ref}Array methods, where X is a value type
Expand Down Expand Up @@ -2264,7 +2271,7 @@ class Definitions {

/** Lists core methods that don't have underlying bytecode, but are synthesized on-the-fly in every reflection universe */
@tu lazy val syntheticCoreMethods: List[TermSymbol] =
AnyMethods ++ ObjectMethods ++ List(String_+, throwMethod)
AnyMethods ++ ObjectMethods ++ List(String_+, throwMethod, spreadMethod)

@tu lazy val reservedScalaClassNames: Set[Name] = syntheticScalaClasses.map(_.name).toSet

Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ object StdNames {
val setSymbol: N = "setSymbol"
val setType: N = "setType"
val setTypeSignature: N = "setTypeSignature"
val spread: N = "spread"
val standardInterpolator: N = "standardInterpolator"
val staticClass : N = "staticClass"
val staticModule : N = "staticModule"
Expand Down
15 changes: 10 additions & 5 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1056,17 +1056,22 @@ object Parsers {
}

/** Is current ident a `*`, and is it followed by a `)`, `, )`, `,EOF`? The latter two are not
syntactically valid, but we need to include them here for error recovery. */
syntactically valid, but we need to include them here for error recovery.
Under experimental.multiSpreads we allow `*`` followed by `,` unconditionally.
*/
def followingIsVararg(): Boolean =
in.isIdent(nme.raw.STAR) && {
val lookahead = in.LookaheadScanner()
lookahead.nextToken()
lookahead.token == RPAREN
|| lookahead.token == COMMA
&& {
lookahead.nextToken()
lookahead.token == RPAREN || lookahead.token == EOF
}
&& (
in.featureEnabled(Feature.multiSpreads)
|| {
lookahead.nextToken()
lookahead.token == RPAREN || lookahead.token == EOF
}
)
}

/** When encountering a `:`, is that in the binding of a lambda?
Expand Down
95 changes: 94 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@ import config.Printers.typr
import config.Feature
import util.{SrcPos, Stats}
import reporting.*
import NameKinds.WildcardParamName
import NameKinds.{WildcardParamName, TempResultName}
import typer.Applications.{spread, HasSpreads}
import typer.Implicits.SearchFailureType
import Constants.Constant
import cc.*
import dotty.tools.dotc.transform.MacroAnnotations.hasMacroAnnotation
import dotty.tools.dotc.core.NameKinds.DefaultGetterName
import ast.TreeInfo

object PostTyper {
val name: String = "posttyper"
Expand Down Expand Up @@ -376,6 +380,93 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
case _ =>
tpt

private def evalSpreadsOnce(trees: List[Tree])(within: List[Tree] => Tree)(using Context): Tree =
if trees.exists:
case spread(elem) => !(exprPurity(elem) >= TreeInfo.Idempotent)
case _ => false
then
val lifted = new mutable.ListBuffer[ValDef]
def liftIfImpure(tree: Tree): Tree = tree match
case tree @ Apply(fn, args) if fn.symbol == defn.spreadMethod =>
cpy.Apply(tree)(fn, args.mapConserve(liftIfImpure))
case _ if tpd.exprPurity(tree) >= TreeInfo.Idempotent =>
tree
case _ =>
val vdef = SyntheticValDef(TempResultName.fresh(), tree).withSpan(tree.span)
lifted += vdef
Ident(vdef.namedType).withSpan(tree.span)
val pureTrees = trees.mapConserve(liftIfImpure)
Block(lifted.toList, within(pureTrees))
else within(trees)

/** Translate sequence literal containing spread operators. Example:
*
* val xs, ys: List[Int]
* [1, xs*, 2, ys*]
*
* Here the sequence literal is translated at typer to
*
* [1, spread(xs), 2, spread(ys)]
*
* This then translates to
*
* scala.runtime.ArraySeqBuilcder.ofInt(2 + xs.length + ys.length)
* .add(1)
* .addSeq(xs)
* .add(2)
* .addSeq(ys)
*
* The reason for doing a two-step typer/postTyper translation is that
* at typer, we don't have all type variables instantiated yet.
*/
private def flattenSpreads[T](tree: SeqLiteral)(using Context): Tree =
val SeqLiteral(rawElems, elemtpt) = tree
val elemType = elemtpt.tpe
val elemCls = elemType.classSymbol

evalSpreadsOnce(rawElems): elems =>
val lengthCalls = elems.collect:
case spread(elem) => elem.select(nme.length)
val singleElemCount: Tree = Literal(Constant(elems.length - lengthCalls.length))
val totalLength =
lengthCalls.foldLeft(singleElemCount): (acc, len) =>
acc.select(defn.Int_+).appliedTo(len)

def makeBuilder(name: String) =
ref(defn.ArraySeqBuilderModule).select(name.toTermName)
def genericBuilder = makeBuilder("generic")
.appliedToType(elemType)
.appliedTo(totalLength)

val builder =
if defn.ScalaValueClasses().contains(elemCls) then
makeBuilder(s"of${elemCls.name}").appliedTo(totalLength)
else if elemCls.derivesFrom(defn.ObjectClass) then
val classTagType = defn.ClassTagClass.typeRef.appliedTo(elemType)
val classTag = atPhase(Phases.typerPhase):
ctx.typer.inferImplicitArg(classTagType, tree.span.startPos)
classTag.tpe match
case _: SearchFailureType =>
genericBuilder
case _ =>
makeBuilder("ofRef")
.appliedToType(elemType)
.appliedTo(totalLength)
.appliedTo(classTag)
else
genericBuilder

elems.foldLeft(builder): (bldr, elem) =>
elem match
case spread(arg) =>
val selector =
if arg.tpe.derivesFrom(defn.SeqClass) then "addSeq"
else "addArray"
bldr.select(selector.toTermName).appliedTo(arg)
case _ => bldr.select("add".toTermName).appliedTo(elem)
.select("result".toTermName)
end flattenSpreads

override def transform(tree: Tree)(using Context): Tree =
try tree match {
// TODO move CaseDef case lower: keep most probable trees first for performance
Expand Down Expand Up @@ -592,6 +683,8 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
case tree: RefinedTypeTree =>
Checking.checkPolyFunctionType(tree)
super.transform(tree)
case tree: SeqLiteral if tree.hasAttachment(HasSpreads) =>
flattenSpreads(tree)
case _: Quote | _: QuotePattern =>
ctx.compilationUnit.needsStaging = true
super.transform(tree)
Expand Down
43 changes: 33 additions & 10 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import Inferencing.*
import reporting.*
import Nullables.*, NullOpsDecorator.*
import config.{Feature, MigrationVersion, SourceVersion}
import util.Property

import collection.mutable
import config.Printers.{overload, typr, unapp}
Expand All @@ -42,6 +43,17 @@ import dotty.tools.dotc.inlines.Inlines
object Applications {
import tpd.*

/** Attachment key for SeqLiterals containing spreads. Eliminated at PostTyper */
val HasSpreads = new Property.StickyKey[Unit]

/** An extractor for spreads in sequence literals */
object spread:
def apply(arg: Tree, elemtpt: Tree)(using Context) =
ref(defn.spreadMethod).appliedToTypeTree(elemtpt).appliedTo(arg)
def unapply(arg: Apply)(using Context): Option[Tree] = arg match
case Apply(fn, x :: Nil) if fn.symbol == defn.spreadMethod => Some(x)
case _ => None

def extractorMember(tp: Type, name: Name)(using Context): SingleDenotation =
tp.member(name).suchThat(sym => sym.info.isParameterless && sym.info.widenExpr.isValueType)

Expand Down Expand Up @@ -797,14 +809,19 @@ trait Applications extends Compatibility {
addTyped(arg)
case _ =>
val elemFormal = formal.widenExpr.argTypesLo.head
val typedArgs =
harmonic(harmonizeArgs, elemFormal) {
args.map { arg =>
if Feature.enabled(Feature.multiSpreads)
&& !ctx.isAfterTyper && args.exists(isVarArg)
then
args.foreach: arg =>
if isVarArg(arg)
then addArg(typedArg(arg, formal), formal)
else addArg(typedArg(arg, elemFormal), elemFormal)
else
val typedArgs = harmonic(harmonizeArgs, elemFormal):
args.map: arg =>
checkNoVarArg(arg)
typedArg(arg, elemFormal)
}
}
typedArgs.foreach(addArg(_, elemFormal))
typedArgs.foreach(addArg(_, elemFormal))
makeVarArg(args.length, elemFormal)
}
else args match {
Expand Down Expand Up @@ -944,12 +961,18 @@ trait Applications extends Compatibility {
typedArgBuf += typedArg
ok = ok & !typedArg.tpe.isError

def makeVarArg(n: Int, elemFormal: Type): Unit = {
def makeVarArg(n: Int, elemFormal: Type): Unit =
val args = typedArgBuf.takeRight(n).toList
typedArgBuf.dropRightInPlace(n)
val elemtpt = TypeTree(elemFormal.normalizedTupleType, inferred = true)
typedArgBuf += seqToRepeated(SeqLiteral(args, elemtpt))
}
val elemTpe = elemFormal.normalizedTupleType
val elemtpt = TypeTree(elemTpe, inferred = true)
def wrapSpread(arg: Tree): Tree = arg match
case Typed(argExpr, tpt) if tpt.tpe.isRepeatedParam => spread(argExpr, elemtpt)
case _ => arg
val args1 = args.mapConserve(wrapSpread)
val seqLit = SeqLiteral(args1, elemtpt)
if args1 ne args then seqLit.putAttachment(HasSpreads, ())
typedArgBuf += seqToRepeated(seqLit)

def harmonizeArgs(args: List[TypedArg]): List[Tree] =
// harmonize args only if resType depends on parameter types
Expand Down
Empty file.
6 changes: 5 additions & 1 deletion library/src/scala/language.scala
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,15 @@ object language {
@compileTimeOnly("`packageObjectValues` can only be used at compile time in import statements")
object packageObjectValues

/** Experimental support for multiple spread arguments.
*/
@compileTimeOnly("`multiSpreads` can only be used at compile time in import statements")
object multiSpreads

/** Experimental support for match expressions with sub cases.
*/
@compileTimeOnly("`subCases` can only be used at compile time in import statements")
object subCases

}

/** The deprecated object contains features that are no longer officially suypported in Scala.
Expand Down
Loading
Loading