Skip to content

Commit 5a233c9

Browse files
oderskynicolasstucki
authored andcommitted
Eliminate polyDefDef def and calls
1 parent d4d604d commit 5a233c9

File tree

15 files changed

+73
-123
lines changed

15 files changed

+73
-123
lines changed

compiler/src-bootstrapped/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,10 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
257257

258258
object DefDef extends DefDefModule:
259259
def apply(symbol: Symbol, rhsFn: List[TypeRepr] => List[List[Term]] => Option[Term]): DefDef =
260-
withDefaultPos(tpd.polyDefDef(symbol.asTerm,
261-
tparams => vparamss => yCheckedOwners(rhsFn(tparams.map(_.tpe))(vparamss), symbol).getOrElse(tpd.EmptyTree)))
260+
withDefaultPos(tpd.DefDef(symbol.asTerm, prefss => {
261+
val (tparams, vparamss) = tpd.splitArgs(prefss)
262+
yCheckedOwners(rhsFn(tparams.map(_.tpe))(vparamss), symbol).getOrElse(tpd.EmptyTree)
263+
}))
262264
def copy(original: Tree)(name: String, typeParams: List[TypeDef], paramss: List[List[ValDef]], tpt: TypeTree, rhs: Option[Term]): DefDef =
263265
tpd.cpy.DefDef(original)(name.toTermName, tpd.joinParams(typeParams, paramss), tpt, yCheckedOwners(rhs, original.symbol).getOrElse(tpd.EmptyTree))
264266
def unapply(ddef: DefDef): (String, List[TypeDef], List[List[ValDef]], TypeTree, Option[Term]) =

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 20 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
224224
def DefDef(sym: TermSymbol, rhs: Tree = EmptyTree)(using Context): DefDef =
225225
ta.assignType(DefDef(sym, Function.const(rhs) _), sym)
226226

227+
/** A DefDef with given method symbol `sym`.
228+
* @rhsFn A function from parameter references
229+
* to the method's right-hand side.
230+
* Parameter symbols are taken from the `rawParamss` field of `sym`, or
231+
* are freshly generated if `rawParamss` is empty.
232+
*/
227233
def DefDef(sym: TermSymbol, rhsFn: List[List[Tree]] => Tree)(using Context): DefDef =
228234

229235
// Map method type `tp` with remaining parameters stored in rawParamss to
@@ -277,69 +283,6 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
277283
DefDef(sym, paramss, rtp, rhsFn(paramss.nestedMap(ref)))
278284
end DefDef
279285

280-
/** A DefDef with given method symbol `sym`.
281-
* @rhsFn A function from type parameter types and term parameter references
282-
* to the method's right-hand side.
283-
* Parameter symbols are taken from the `rawParamss` field of `sym`, or
284-
* are freshly generated if `rawParamss` is empty.
285-
*/
286-
def polyDefDef(sym: TermSymbol, rhsFn: List[Tree] => List[List[Tree]] => Tree)(using Context): DefDef = {
287-
288-
val (tparams, existingParamss, mtp) = sym.info match {
289-
case tp: PolyType =>
290-
val (tparams, existingParamss) = sym.rawParamss match
291-
case tparams :: vparamss =>
292-
assert(tparams.hasSameLengthAs(tp.paramNames) && tparams.head.isType)
293-
(tparams.asInstanceOf[List[TypeSymbol]], vparamss)
294-
case _ =>
295-
(newTypeParams(sym, tp.paramNames, EmptyFlags, tp.instantiateParamInfos(_)), Nil)
296-
(tparams, existingParamss, tp.instantiate(tparams map (_.typeRef)))
297-
case tp => (Nil, sym.rawParamss, tp)
298-
}
299-
300-
def valueParamss(tp: Type, existingParamss: List[List[Symbol]]): (List[List[TermSymbol]], Type) = tp match {
301-
case tp: MethodType =>
302-
val isParamDependent = tp.isParamDependent
303-
val previousParamRefs = if (isParamDependent) mutable.ListBuffer[TermRef]() else null
304-
305-
def valueParam(name: TermName, origInfo: Type): TermSymbol = {
306-
val maybeImplicit =
307-
if (tp.isContextualMethod) Given
308-
else if (tp.isImplicitMethod) Implicit
309-
else EmptyFlags
310-
val maybeErased = if (tp.isErasedMethod) Erased else EmptyFlags
311-
312-
def makeSym(info: Type) = newSymbol(sym, name, TermParam | maybeImplicit | maybeErased, info, coord = sym.coord)
313-
314-
if (isParamDependent) {
315-
val sym = makeSym(origInfo.substParams(tp, previousParamRefs.toList))
316-
previousParamRefs += sym.termRef
317-
sym
318-
}
319-
else
320-
makeSym(origInfo)
321-
}
322-
323-
val (params, existingParamss1) =
324-
if tp.paramInfos.isEmpty then (Nil, existingParamss)
325-
else existingParamss match
326-
case vparams :: existingParamss1 =>
327-
assert(vparams.hasSameLengthAs(tp.paramNames) && vparams.head.isTerm)
328-
(vparams.asInstanceOf[List[TermSymbol]], existingParamss1)
329-
case _ =>
330-
(tp.paramNames.lazyZip(tp.paramInfos).map(valueParam), Nil)
331-
val (paramss, rtp) =
332-
valueParamss(tp.instantiate(params map (_.termRef)), existingParamss1)
333-
(params :: paramss, rtp)
334-
case tp => (Nil, tp.widenExpr)
335-
}
336-
val (vparamss, rtp) = valueParamss(mtp, existingParamss)
337-
val targs = tparams.map(tparam => ref(tparam.typeRef))
338-
val argss = vparamss.nestedMap(vparam => Ident(vparam.termRef))
339-
sym.setParamss(tparams :: vparamss)
340-
DefDef(sym, joinSymbols(tparams, vparamss), rtp, rhsFn(targs)(argss))
341-
}
342-
343286
def TypeDef(sym: TypeSymbol)(using Context): TypeDef =
344287
ta.assignType(untpd.TypeDef(sym.name, TypeTree(sym.info)), sym)
345288

@@ -406,7 +349,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
406349
for overridden <- fwdMeth.allOverriddenSymbols do
407350
if overridden.is(Extension) then fwdMeth.setFlag(Extension)
408351
if !overridden.is(Deferred) then fwdMeth.setFlag(Override)
409-
polyDefDef(fwdMeth, tprefs => prefss => ref(fn).appliedToTypeTrees(tprefs).appliedToArgss(prefss))
352+
DefDef(fwdMeth, ref(fn).appliedToArgss(_))
410353
}
411354
val forwarders = fns.lazyZip(methNames).map(forwarder)
412355
val cdef = ClassDef(cls, DefDef(constr), forwarders)
@@ -1287,12 +1230,21 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
12871230
Ident(defn.ScalaRuntimeModule.requiredMethod(name).termRef).appliedToTermArgs(args)
12881231

12891232
/** An extractor that pulls out type arguments */
1290-
object MaybePoly {
1291-
def unapply(tree: Tree): Option[(Tree, List[Tree])] = tree match {
1233+
object MaybePoly:
1234+
def unapply(tree: Tree): Option[(Tree, List[Tree])] = tree match
12921235
case TypeApply(tree, targs) => Some(tree, targs)
12931236
case _ => Some(tree, Nil)
1294-
}
1295-
}
1237+
1238+
object TypeArgs:
1239+
def unapply(ts: List[Tree]): Option[List[Tree]] =
1240+
if ts.nonEmpty && ts.head.isType then Some(ts) else None
1241+
1242+
/** Split argument clauses into a leading type argument clause if it exists and
1243+
* remaining clauses
1244+
*/
1245+
def splitArgs(argss: List[List[Tree]]): (List[Tree], List[List[Tree]]) = argss match
1246+
case TypeArgs(targs) :: argss1 => (targs, argss1)
1247+
case _ => (Nil, argss)
12961248

12971249
/** A key to be used in a context property that tracks enclosing inlined calls */
12981250
private val InlinedCalls = Property.Key[List[Tree]]()

compiler/src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -877,9 +877,6 @@ object Symbols {
877877
case (x: Symbol) :: _ if x.isType => Some(xs.asInstanceOf[List[TypeSymbol]])
878878
case _ => None
879879

880-
def joinSymbols(xs: List[Symbol], ys: List[List[Symbol]]): List[List[Symbol]] =
881-
if xs.isEmpty then ys else xs :: ys
882-
883880
// ----- Locating predefined symbols ----------------------------------------
884881

885882
def requiredPackage(path: PreName)(using Context): TermSymbol = {

compiler/src/dotty/tools/dotc/transform/AccessProxies.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,18 @@ abstract class AccessProxies {
3737
*/
3838
private def accessorDefs(cls: Symbol)(using Context): Iterator[DefDef] =
3939
for (accessor <- cls.info.decls.iterator; accessed <- accessedBy.remove(accessor).toOption) yield
40-
polyDefDef(accessor.asTerm, tps => argss => {
40+
DefDef(accessor.asTerm, prefss => {
4141
def numTypeParams = accessed.info match {
4242
case info: PolyType => info.paramNames.length
4343
case _ => 0
4444
}
45+
val (targs, argss) = splitArgs(prefss)
4546
val (accessRef, forwardedTpts, forwardedArgss) =
4647
if (passReceiverAsArg(accessor.name))
47-
(argss.head.head.select(accessed), tps.takeRight(numTypeParams), argss.tail)
48+
(argss.head.head.select(accessed), targs.takeRight(numTypeParams), argss.tail)
4849
else
4950
(if (accessed.isStatic) ref(accessed) else ref(TermRef(cls.thisType, accessed)),
50-
tps, argss)
51+
targs, argss)
5152
val rhs =
5253
if (accessor.name.isSetterName &&
5354
forwardedArgss.nonEmpty && forwardedArgss.head.nonEmpty) // defensive conditions

compiler/src/dotty/tools/dotc/transform/ElimRepeated.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,13 +222,12 @@ class ElimRepeated extends MiniPhase with InfoTransformer { thisPhase =>
222222
.get
223223
.symbol.asTerm
224224
// Generate the method
225-
val forwarderDef = polyDefDef(forwarderSym, trefs => vrefss => {
226-
val init :+ (last :+ vararg) = vrefss
225+
val forwarderDef = DefDef(forwarderSym, prefss => {
226+
val init :+ (last :+ vararg) = prefss
227227
// Can't call `.argTypes` here because the underlying array type is of the
228228
// form `Array[? <: SomeType]`, so we need `.argInfos` to get the `TypeBounds`.
229229
val elemtp = vararg.tpe.widen.argInfos.head
230230
ref(sym.termRef)
231-
.appliedToTypeTrees(trefs)
232231
.appliedToArgss(init)
233232
.appliedToTermArgs(last :+ wrapArray(vararg, elemtp))
234233
})

compiler/src/dotty/tools/dotc/transform/FirstTransform.scala

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -117,17 +117,14 @@ class FirstTransform extends MiniPhase with InfoTransformer { thisPhase =>
117117
override def transformTemplate(impl: Template)(using Context): Tree =
118118
cpy.Template(impl)(self = EmptyValDef)
119119

120-
override def transformDefDef(ddef: DefDef)(using Context): Tree = {
120+
override def transformDefDef(ddef: DefDef)(using Context): Tree =
121121
val meth = ddef.symbol.asTerm
122-
if (meth.hasAnnotation(defn.NativeAnnot)) {
122+
if meth.hasAnnotation(defn.NativeAnnot) then
123123
meth.resetFlag(Deferred)
124-
polyDefDef(meth,
125-
_ => _ => ref(defn.Sys_error.termRef).withSpan(ddef.span)
124+
DefDef(meth, _ =>
125+
ref(defn.Sys_error.termRef).withSpan(ddef.span)
126126
.appliedTo(Literal(Constant(s"native method stub"))))
127-
}
128-
129127
else ddef
130-
}
131128

132129
override def transformStats(trees: List[Tree])(using Context): List[Tree] =
133130
ast.Trees.flatten(atPhase(thisPhase.next)(reorderAndComplete(trees)))

compiler/src/dotty/tools/dotc/transform/FullParameterization.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ trait FullParameterization {
148148
* of class that contained original defDef
149149
*/
150150
def fullyParameterizedDef(derived: TermSymbol, originalDef: DefDef, abstractOverClass: Boolean = true)(using Context): Tree =
151-
polyDefDef(derived, trefs => vrefss => {
151+
DefDef(derived, prefss => {
152+
val (trefs, vrefss) = splitArgs(prefss)
152153
val origMeth = originalDef.symbol
153154
val origClass = origMeth.enclosingClass.asClass
154155
val origLeadingTypeParamSyms = allInstanceTypeParams(originalDef, abstractOverClass)

compiler/src/dotty/tools/dotc/transform/HoistSuperArgs.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,9 @@ class HoistSuperArgs extends MiniPhase with IdentityDenotTransformer { thisPhase
129129
cpy.Apply(arg)(fn, hoistSuperArg(arg1, cdef) :: Nil)
130130
case _ if arg.existsSubTree(needsHoist) =>
131131
val superMeth = newSuperArgMethod(arg.tpe)
132-
val superArgDef = polyDefDef(superMeth, trefs => vrefss => {
133-
val paramSyms = trefs.map(_.tpe.typeSymbol) ::: vrefss.flatten.map(_.symbol)
132+
val superArgDef = DefDef(superMeth, prefss => {
133+
val paramSyms = prefss.flatten.map(pref =>
134+
if pref.isType then pref.tpe.typeSymbol else pref.symbol)
134135
val tmap = new TreeTypeMap(
135136
typeMap = new TypeMap {
136137
lazy val origToParam = origParams.zip(paramSyms).toMap

compiler/src/dotty/tools/dotc/transform/Mixin.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ class Mixin extends MiniPhase with SymTransformer { thisPhase =>
285285
for (meth <- mixin.info.decls.toList if needsMixinForwarder(meth))
286286
yield {
287287
util.Stats.record("mixin forwarders")
288-
transformFollowing(polyDefDef(mkForwarderSym(meth.asTerm, Bridge), forwarderRhsFn(meth)))
288+
transformFollowing(DefDef(mkForwarderSym(meth.asTerm, Bridge), forwarderRhsFn(meth)))
289289
}
290290

291291
cpy.Template(impl)(

compiler/src/dotty/tools/dotc/transform/MixinOps.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,18 +77,17 @@ class MixinOps(cls: ClassSymbol, thisPhase: DenotTransformer)(using Context) {
7777
final val PrivateOrAccessor: FlagSet = Private | Accessor
7878
final val PrivateOrAccessorOrDeferred: FlagSet = Private | Accessor | Deferred
7979

80-
def forwarderRhsFn(target: Symbol): List[Tree] => List[List[Tree]] => Tree = {
81-
targs => vrefss =>
80+
def forwarderRhsFn(target: Symbol): List[List[Tree]] => Tree =
81+
prefss =>
82+
val (targs, vargss) = splitArgs(prefss)
8283
val tapp = superRef(target).appliedToTypeTrees(targs)
83-
vrefss match {
84+
vargss match
8485
case Nil | List(Nil) =>
8586
// Overriding is somewhat loose about `()T` vs `=> T`, so just pick
8687
// whichever makes sense for `target`
8788
tapp.ensureApplied
8889
case _ =>
89-
tapp.appliedToArgss(vrefss)
90-
}
91-
}
90+
tapp.appliedToArgss(vargss)
9291

9392
private def competingMethodsIterator(meth: Symbol): Iterator[Symbol] =
9493
cls.baseClasses.iterator

0 commit comments

Comments
 (0)