@@ -22,6 +22,7 @@ import scala.collection.mutable
2222import scala .util .boundary , boundary .break
2323import dotty .tools .dotc .core .StdNames .nme
2424import dotty .tools .unreachable
25+ import dotty .tools .dotc .util .Spans .Span
2526
2627/** Implementation of SIP-61.
2728 * Runs when `@unroll` annotations are found in a compilation unit, installing new definitions
@@ -33,16 +34,10 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
3334
3435 import tpd .*
3536
36- private var _unrolledDefs : util.EqHashMap [Symbol , ComputedIndicies ] | Null = null
37- private def initializeUnrolledDefs (): util.EqHashMap [Symbol , ComputedIndicies ] =
38- val local = _unrolledDefs
39- if local == null then
40- val map = new util.EqHashMap [Symbol , ComputedIndicies ]
41- _unrolledDefs = map
42- map
43- else
44- local.clear()
45- local
37+ private val _unrolledDefs : util.EqHashMap [Symbol , ComputedIndices ] = new util.EqHashMap [Symbol , ComputedIndices ]
38+ private def initializeUnrolledDefs (): util.EqHashMap [Symbol , ComputedIndices ] =
39+ _unrolledDefs.clear()
40+ _unrolledDefs
4641
4742 override def phaseName : String = UnrollDefinitions .name
4843
@@ -55,15 +50,15 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
5550 super .run // create and run the transformer on the current compilation unit
5651
5752 def newTransformer (using Context ): Transformer =
58- UnrollingTransformer (ctx.compilationUnit.unrolledClasses.nn )
53+ UnrollingTransformer (ctx.compilationUnit.unrolledClasses)
5954
60- type ComputedIndicies = List [(Int , List [Int ])]
61- type ComputeIndicies = Context ?=> Symbol => ComputedIndicies
55+ type ComputedIndices = List [(Int , List [Int ])]
56+ type ComputeIndices = Context ?=> Symbol => ComputedIndices
6257
63- private class UnrollingTransformer (classes : Set [Symbol ]) extends Transformer {
58+ private class UnrollingTransformer (unrolledClasses : Set [Symbol ]) extends Transformer {
6459 private val unrolledDefs = initializeUnrolledDefs()
6560
66- def computeIndices (annotated : Symbol )(using Context ): ComputedIndicies =
61+ def computeIndices (annotated : Symbol )(using Context ): ComputedIndices =
6762 unrolledDefs.getOrElseUpdate(annotated, {
6863 if annotated.name.is(DefaultGetterName ) then
6964 Nil // happens in curried methods where more than one parameter list has @unroll
@@ -84,25 +79,25 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
8479 end computeIndices
8580
8681 override def transform (tree : tpd.Tree )(using Context ): tpd.Tree = tree match
87- case tree @ TypeDef (_, impl : Template ) if classes (tree.symbol) =>
82+ case tree @ TypeDef (_, impl : Template ) if unrolledClasses (tree.symbol) =>
8883 super .transform(cpy.TypeDef (tree)(rhs = unrollTemplate(impl, computeIndices)))
8984 case tree =>
9085 super .transform(tree)
9186 }
9287
93- def copyParamSym (sym : Symbol , parent : Symbol )(using Context ): (Symbol , Symbol ) =
88+ private def copyParamSym (sym : Symbol , parent : Symbol )(using Context ): (Symbol , Symbol ) =
9489 val copied = sym.copy(owner = parent, flags = (sym.flags &~ HasDefault ), coord = sym.coord)
9590 sym -> copied
9691
97- def symLocation (sym : Symbol )(using Context ) = {
92+ private def symLocation (sym : Symbol )(using Context ) = {
9893 val lineDesc =
9994 if (sym.span.exists && sym.span != sym.owner.span)
10095 s " at line ${sym.srcPos.line + 1 }"
10196 else " "
10297 i " in ${sym.owner}${lineDesc}"
10398 }
10499
105- def findUnrollAnnotations (params : List [Symbol ])(using Context ): List [Int ] = {
100+ private def findUnrollAnnotations (params : List [Symbol ])(using Context ): List [Int ] = {
106101 params
107102 .zipWithIndex
108103 .collect {
@@ -111,16 +106,25 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
111106 }
112107 }
113108
114- def isTypeClause (p : ParamClause ) = p.headOption.exists(_.isInstanceOf [TypeDef ])
115-
116- def generateSingleForwarder (defdef : DefDef ,
117- prevMethodType : Type ,
109+ private def isTypeClause (p : ParamClause ) = p.headOption.exists(_.isInstanceOf [TypeDef ])
110+
111+ /** Generate a forwarder that calls the next one in a "chain" of forwarders
112+ *
113+ * @param defdef the original unrolled def that the forwarder is derived from
114+ * @param paramIndex index of the unrolled parameter (in the parameter list) that we stop at
115+ * @param paramCount number of parameters in the annotated parameter list
116+ * @param nextParamIndex index of next unrolled parameter - to fetch default argument
117+ * @param nextSpan span of next forwarder - used to ensure the span is not identical by shifting (TODO remove)
118+ * @param annotatedParamListIndex index of the parameter list that contains unrolled parameters
119+ * @param isCaseApply if `defdef` is a case class apply/constructor - used for selection of default arguments
120+ */
121+ private def generateSingleForwarder (defdef : DefDef ,
118122 paramIndex : Int ,
119123 paramCount : Int ,
120124 nextParamIndex : Int ,
121- nextSymbol : Symbol ,
125+ nextSpan : Span ,
122126 annotatedParamListIndex : Int ,
123- isCaseApply : Boolean )(using Context ) = {
127+ isCaseApply : Boolean )(using Context ): DefDef = {
124128
125129 def initNewForwarder ()(using Context ): (TermSymbol , List [List [Symbol ]]) = {
126130 val forwarderDefSymbol0 = Symbols .newSymbol(
@@ -129,15 +133,15 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
129133 defdef.symbol.flags &~ HasDefaultParams |
130134 Invisible | Synthetic ,
131135 NoType , // fill in later
132- coord = nextSymbol.span .shift(1 ) // shift by 1 to avoid "secondary constructor must call preceding" error
136+ coord = nextSpan .shift(1 ) // shift by 1 to avoid "secondary constructor must call preceding" error
133137 ).entered
134138
135139 val newParamSymMappings = extractParamSymss(copyParamSym(_, forwarderDefSymbol0))
136140 val (oldParams, newParams) = newParamSymMappings.flatten.unzip
137141
138142 val newParamSymLists0 =
139- newParamSymMappings.map: pairss =>
140- pairss .map: (oldSym, newSym) =>
143+ newParamSymMappings.map: pairs =>
144+ pairs .map: (oldSym, newSym) =>
141145 newSym.info = oldSym.info.substSym(oldParams, newParams)
142146 newSym
143147
@@ -153,8 +157,6 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
153157 else ps.map(p => onSymbol(p.symbol))
154158 }
155159
156- val paramCount = defdef.symbol.paramSymss(annotatedParamListIndex).size
157-
158160 val (forwarderDefSymbol, newParamSymLists) = initNewForwarder()
159161
160162 def forwarderRhs (): tpd.Tree = {
@@ -224,10 +226,10 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
224226 val forwarderDef =
225227 tpd.DefDef (forwarderDefSymbol, rhs = forwarderRhs())
226228
227- forwarderDef.withSpan(nextSymbol.span .shift(1 ))
229+ forwarderDef.withSpan(nextSpan .shift(1 ))
228230 }
229231
230- def generateFromProduct (startParamIndices : List [Int ], paramCount : Int , defdef : DefDef )(using Context ) = {
232+ private def generateFromProduct (startParamIndices : List [Int ], paramCount : Int , defdef : DefDef )(using Context ) = {
231233 cpy.DefDef (defdef)(
232234 name = defdef.name,
233235 paramss = defdef.paramss,
@@ -248,28 +250,35 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
248250 )
249251 )
250252 )
251- } ++ Seq (
252- CaseDef (
253- Underscore (defn.IntType ),
254- EmptyTree ,
255- defdef.rhs
256- )
253+ } :+ CaseDef (
254+ Underscore (defn.IntType ),
255+ EmptyTree ,
256+ defdef.rhs
257257 )
258258 )
259259 ).setDefTree
260260 }
261261
262- def generateSyntheticDefs (tree : Tree , compute : ComputeIndicies )(using Context ): Option [(Symbol , Option [Symbol ], Seq [DefDef ])] = tree match {
262+ private enum Gen :
263+ case Substitute (origin : Symbol , newDef : DefDef )
264+ case Forwarders (origin : Symbol , forwarders : Seq [DefDef ])
265+
266+ def origin : Symbol
267+ def extras : Seq [DefDef ] = this match
268+ case Substitute (_, d) => d :: Nil
269+ case Forwarders (_, ds) => ds
270+
271+ private def generateSyntheticDefs (tree : Tree , compute : ComputeIndices )(using Context ): Option [Gen ] = tree match {
263272 case defdef : DefDef if defdef.paramss.nonEmpty =>
264273 import dotty .tools .dotc .core .NameOps .isConstructorName
265274
266275 val isCaseCopy =
267- defdef.name.toString == " copy" && defdef.symbol.owner.is(CaseClass )
276+ defdef.name == nme. copy && defdef.symbol.owner.is(CaseClass )
268277
269278 val isCaseApply =
270- defdef.name.toString == " apply" && defdef.symbol.owner.companionClass.is(CaseClass )
279+ defdef.name == nme. apply && defdef.symbol.owner.companionClass.is(CaseClass )
271280
272- val isCaseFromProduct = defdef.name.toString == " fromProduct" && defdef.symbol.owner.companionClass.is(CaseClass )
281+ val isCaseFromProduct = defdef.name == nme. fromProduct && defdef.symbol.owner.companionClass.is(CaseClass )
273282
274283 val annotated =
275284 if (isCaseCopy) defdef.symbol.owner.primaryConstructor
@@ -282,25 +291,27 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
282291 case Seq ((paramClauseIndex, annotationIndices)) =>
283292 val paramCount = annotated.paramSymss(paramClauseIndex).size
284293 if isCaseFromProduct then
285- Some ((defdef.symbol, Some (defdef.symbol), Seq (generateFromProduct(annotationIndices, paramCount, defdef))))
294+ Some (Gen .Substitute (
295+ origin = defdef.symbol,
296+ newDef = generateFromProduct(annotationIndices, paramCount, defdef)
297+ ))
286298 else
287299 val (generatedDefs, _) =
288300 val indices = (annotationIndices :+ paramCount).sliding(2 ).toList.reverse
289- indices.foldLeft((Seq .empty[DefDef ], defdef.symbol)):
290- case ((defdefs, nextSymbol ), Seq (paramIndex, nextParamIndex)) =>
301+ indices.foldLeft((Seq .empty[DefDef ], defdef.symbol.span )):
302+ case ((defdefs, nextSpan ), Seq (paramIndex, nextParamIndex)) =>
291303 val forwarder = generateSingleForwarder(
292304 defdef,
293- defdef.symbol.info,
294305 paramIndex,
295306 paramCount,
296307 nextParamIndex,
297- nextSymbol ,
308+ nextSpan ,
298309 paramClauseIndex,
299310 isCaseApply
300311 )
301- (forwarder +: defdefs, forwarder.symbol)
312+ (forwarder +: defdefs, forwarder.symbol.span )
302313 case _ => unreachable(" sliding with at least 2 elements" )
303- Some (( defdef.symbol, None , generatedDefs))
314+ Some (Gen . Forwarders (origin = defdef.symbol, forwarders = generatedDefs))
304315
305316 case multiple =>
306317 report.error(" Cannot have multiple parameter lists containing `@unroll` annotation" , defdef.srcPos)
@@ -310,12 +321,12 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
310321 case _ => None
311322 }
312323
313- def unrollTemplate (tmpl : tpd.Template , compute : ComputeIndicies )(using Context ): tpd.Tree = {
324+ private def unrollTemplate (tmpl : tpd.Template , compute : ComputeIndices )(using Context ): tpd.Tree = {
314325
315326 val generatedBody = tmpl.body.flatMap(generateSyntheticDefs(_, compute))
316327 val generatedConstr0 = generateSyntheticDefs(tmpl.constr, compute)
317328 val allGenerated = generatedBody ++ generatedConstr0
318- val bodySubs = generatedBody.flatMap((_, maybeSub, _) => maybeSub ).toSet
329+ val bodySubs = generatedBody.collect({ case s : Gen . Substitute => s.origin } ).toSet
319330 val otherDecls = tmpl.body.filterNot(d => d.symbol.exists && bodySubs(d.symbol))
320331
321332 /** inlined from compiler/src/dotty/tools/dotc/typer/Checking.scala */
@@ -326,28 +337,26 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
326337 if allGenerated.nonEmpty then
327338 val byName = (tmpl.constr :: otherDecls).groupMap(_.symbol.name.toString)(_.symbol)
328339 for
329- (src, _, dcls) <- allGenerated
330- dcl <- dcls
340+ syntheticDefs <- allGenerated
341+ dcl <- syntheticDefs.extras
331342 do
332343 val replaced = dcl.symbol
333344 byName.get(dcl.name.toString).foreach { syms =>
334345 val clashes = syms.filter(checkClash(replaced, _))
335346 for existing <- clashes do
347+ val src = syntheticDefs.origin
336348 report.error(i """ Unrolled $replaced clashes with existing declaration.
337349 |Please remove the clashing definition, or the @unroll annotation.
338350 |Unrolled from ${hl(src.showDcl)} ${symLocation(src)}""" .stripMargin, existing.srcPos)
339351 }
340352 end if
341353
342- val generatedDefs = generatedBody.flatMap((_, _, gens) => gens)
343- val generatedConstr = generatedConstr0.toList.flatMap((_, _, gens) => gens)
344-
345354 cpy.Template (tmpl)(
346355 tmpl.constr,
347356 tmpl.parents,
348357 tmpl.derived,
349358 tmpl.self,
350- otherDecls ++ generatedDefs ++ generatedConstr
359+ otherDecls ++ allGenerated.flatMap(_.extras)
351360 )
352361 }
353362
0 commit comments