Skip to content

Commit 77704dd

Browse files
committed
address review: part 1
1 parent 8d3db3e commit 77704dd

File tree

10 files changed

+89
-92
lines changed

10 files changed

+89
-92
lines changed

compiler/src/dotty/tools/dotc/CompilationUnit.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ class CompilationUnit protected (val source: SourceFile, val info: CompilationUn
5959

6060
var hasMacroAnnotations: Boolean = false
6161

62-
def hasUnrollDefs: Boolean = unrolledClasses != null
63-
var unrolledClasses: Set[Symbol] | Null = null
62+
def hasUnrollDefs: Boolean = unrolledClasses.nonEmpty
63+
var unrolledClasses: Set[Symbol] = Set.empty
6464

6565
/** Set to `true` if inliner added anonymous mirrors that need to be completed */
6666
var needsMirrorSupport: Boolean = false

compiler/src/dotty/tools/dotc/core/Denotations.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1074,7 +1074,9 @@ object Denotations {
10741074
def filterDisjoint(denots: PreDenotation)(using Context): SingleDenotation =
10751075
if (denots.exists && denots.matches(this)) NoDenotation else this
10761076
def filterWithFlags(required: FlagSet, excluded: FlagSet)(using Context): SingleDenotation =
1077-
val realExcluded = if ctx.isAfterTyper then excluded else excluded | (if ctx.mode.is(Mode.ResolveFromTASTy) then EmptyFlags else Invisible)
1077+
val realExcluded =
1078+
if ctx.isAfterTyper || ctx.mode.is(Mode.ResolveFromTASTy) then excluded
1079+
else excluded | Invisible
10781080
def symd: SymDenotation = this match
10791081
case symd: SymDenotation => symd
10801082
case _ => symbol.denot

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import collection.mutable
2020
import reporting.{Profile, NoProfile}
2121
import dotty.tools.tasty.TastyFormat.ASTsSection
2222
import quoted.QuotePatterns
23-
import dotty.tools.dotc.config.Feature
2423

2524
object TreePickler:
2625
class StackSizeExceeded(val mdef: tpd.MemberDef) extends Exception
@@ -475,16 +474,15 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
475474
case _ =>
476475
if passesConditionForErroringBestEffortCode(tree.hasType) then
477476
// #19951 The signature of a constructor of a Java annotation is irrelevant
478-
val sym = tree.symbol
479477
val sig =
480-
if name == nme.CONSTRUCTOR && sym.exists && sym.owner.is(JavaAnnotation) then Signature.NotAMethod
478+
if name == nme.CONSTRUCTOR && tree.symbol.exists && tree.symbol.owner.is(JavaAnnotation) then Signature.NotAMethod
481479
else tree.tpe.signature
482-
var ename = sym.targetName
480+
var ename = tree.symbol.targetName
483481
val selectFromQualifier =
484482
name.isTypeName
485483
|| qual.isInstanceOf[Hole] // holes have no symbol
486484
|| sig == Signature.NotAMethod // no overload resolution necessary
487-
|| !sym.exists // polymorphic function type
485+
|| !tree.denot.symbol.exists // polymorphic function type
488486
|| tree.denot.asSingleDenotation.isRefinedMethod // refined methods have no defining class symbol
489487
if selectFromQualifier then
490488
writeByte(if name.isTypeName then SELECTtpt else SELECT)
@@ -493,9 +491,9 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
493491
else // select from owner
494492
writeByte(SELECTin)
495493
withLength {
496-
pickleNameAndSig(name, sym.signature, ename)
494+
pickleNameAndSig(name, tree.symbol.signature, ename)
497495
pickleTree(qual)
498-
pickleType(sym.owner.typeRef)
496+
pickleType(tree.symbol.owner.typeRef)
499497
}
500498
else
501499
writeByte(if name.isTypeName then SELECTtpt else SELECT)

compiler/src/dotty/tools/dotc/reporting/messages.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3361,7 +3361,7 @@ extends DeclarationMsg(IllegalUnrollPlacementID):
33613361
case Some(method) =>
33623362
val isCtor = method.isConstructor
33633363
def what = if isCtor then i"a ${if method.owner.is(Trait) then "trait" else "class"} constructor" else i"method ${method.name}"
3364-
val prefix = s"Can not unroll parameters of $what"
3364+
val prefix = s"Cannot unroll parameters of $what"
33653365
if method.is(Deferred) then
33663366
i"$prefix: it must not be abstract"
33673367
else if isCtor && method.owner.is(Trait) then

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -120,20 +120,12 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
120120

121121
private var inJavaAnnot: Boolean = false
122122

123-
private var seenUnrolledMethods: util.EqHashMap[Symbol, Boolean] | Null = null
123+
private val seenUnrolledMethods: util.EqHashMap[Symbol, Boolean] = new util.EqHashMap[Symbol, Boolean]
124124

125125
private var noCheckNews: Set[New] = Set()
126126

127127
def isValidUnrolledMethod(method: Symbol, origin: SrcPos)(using Context): Boolean =
128-
val seenMethods =
129-
val local = seenUnrolledMethods
130-
if local == null then
131-
val map = new util.EqHashMap[Symbol, Boolean]
132-
seenUnrolledMethods = map
133-
map
134-
else
135-
local
136-
seenMethods.getOrElseUpdate(method, {
128+
seenUnrolledMethods.getOrElseUpdate(method, {
137129
val isCtor = method.isConstructor
138130
if
139131
method.name.is(DefaultGetterName)
@@ -208,12 +200,8 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
208200
private def registerIfUnrolledParam(sym: Symbol)(using Context): Unit =
209201
if sym.hasAnnotation(defn.UnrollAnnot) && isValidUnrolledMethod(sym.owner, sym.sourcePos) then
210202
val cls = sym.enclosingClass
211-
val classes = ctx.compilationUnit.unrolledClasses
212203
val additions = Array(cls, cls.linkedClass).filter(_ != NoSymbol)
213-
if classes == null then
214-
ctx.compilationUnit.unrolledClasses = Set.from(additions)
215-
else
216-
ctx.compilationUnit.unrolledClasses = classes ++ additions
204+
ctx.compilationUnit.unrolledClasses ++= additions
217205

218206
private def processValOrDefDef(tree: Tree)(using Context): tree.type =
219207
val sym = tree.symbol

compiler/src/dotty/tools/dotc/transform/UnrollDefinitions.scala

Lines changed: 66 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import scala.collection.mutable
2222
import scala.util.boundary, boundary.break
2323
import dotty.tools.dotc.core.StdNames.nme
2424
import dotty.tools.unreachable
25+
import dotty.tools.dotc.util.Spans.Span
2526

2627
/**Implementation of SIP-61.
2728
* Runs when `@unroll` annotations are found in a compilation unit, installing new definitions
@@ -33,16 +34,10 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
3334

3435
import tpd.*
3536

36-
private var _unrolledDefs: util.EqHashMap[Symbol, ComputedIndicies] | Null = null
37-
private def initializeUnrolledDefs(): util.EqHashMap[Symbol, ComputedIndicies] =
38-
val local = _unrolledDefs
39-
if local == null then
40-
val map = new util.EqHashMap[Symbol, ComputedIndicies]
41-
_unrolledDefs = map
42-
map
43-
else
44-
local.clear()
45-
local
37+
private val _unrolledDefs: util.EqHashMap[Symbol, ComputedIndices] = new util.EqHashMap[Symbol, ComputedIndices]
38+
private def initializeUnrolledDefs(): util.EqHashMap[Symbol, ComputedIndices] =
39+
_unrolledDefs.clear()
40+
_unrolledDefs
4641

4742
override def phaseName: String = UnrollDefinitions.name
4843

@@ -55,15 +50,15 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
5550
super.run // create and run the transformer on the current compilation unit
5651

5752
def newTransformer(using Context): Transformer =
58-
UnrollingTransformer(ctx.compilationUnit.unrolledClasses.nn)
53+
UnrollingTransformer(ctx.compilationUnit.unrolledClasses)
5954

60-
type ComputedIndicies = List[(Int, List[Int])]
61-
type ComputeIndicies = Context ?=> Symbol => ComputedIndicies
55+
type ComputedIndices = List[(Int, List[Int])]
56+
type ComputeIndices = Context ?=> Symbol => ComputedIndices
6257

63-
private class UnrollingTransformer(classes: Set[Symbol]) extends Transformer {
58+
private class UnrollingTransformer(unrolledClasses: Set[Symbol]) extends Transformer {
6459
private val unrolledDefs = initializeUnrolledDefs()
6560

66-
def computeIndices(annotated: Symbol)(using Context): ComputedIndicies =
61+
def computeIndices(annotated: Symbol)(using Context): ComputedIndices =
6762
unrolledDefs.getOrElseUpdate(annotated, {
6863
if annotated.name.is(DefaultGetterName) then
6964
Nil // happens in curried methods where more than one parameter list has @unroll
@@ -84,25 +79,25 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
8479
end computeIndices
8580

8681
override def transform(tree: tpd.Tree)(using Context): tpd.Tree = tree match
87-
case tree @ TypeDef(_, impl: Template) if classes(tree.symbol) =>
82+
case tree @ TypeDef(_, impl: Template) if unrolledClasses(tree.symbol) =>
8883
super.transform(cpy.TypeDef(tree)(rhs = unrollTemplate(impl, computeIndices)))
8984
case tree =>
9085
super.transform(tree)
9186
}
9287

93-
def copyParamSym(sym: Symbol, parent: Symbol)(using Context): (Symbol, Symbol) =
88+
private def copyParamSym(sym: Symbol, parent: Symbol)(using Context): (Symbol, Symbol) =
9489
val copied = sym.copy(owner = parent, flags = (sym.flags &~ HasDefault), coord = sym.coord)
9590
sym -> copied
9691

97-
def symLocation(sym: Symbol)(using Context) = {
92+
private def symLocation(sym: Symbol)(using Context) = {
9893
val lineDesc =
9994
if (sym.span.exists && sym.span != sym.owner.span)
10095
s" at line ${sym.srcPos.line + 1}"
10196
else ""
10297
i"in ${sym.owner}${lineDesc}"
10398
}
10499

105-
def findUnrollAnnotations(params: List[Symbol])(using Context): List[Int] = {
100+
private def findUnrollAnnotations(params: List[Symbol])(using Context): List[Int] = {
106101
params
107102
.zipWithIndex
108103
.collect {
@@ -111,16 +106,25 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
111106
}
112107
}
113108

114-
def isTypeClause(p: ParamClause) = p.headOption.exists(_.isInstanceOf[TypeDef])
115-
116-
def generateSingleForwarder(defdef: DefDef,
117-
prevMethodType: Type,
109+
private def isTypeClause(p: ParamClause) = p.headOption.exists(_.isInstanceOf[TypeDef])
110+
111+
/** Generate a forwarder that calls the next one in a "chain" of forwarders
112+
*
113+
* @param defdef the original unrolled def that the forwarder is derived from
114+
* @param paramIndex index of the unrolled parameter (in the parameter list) that we stop at
115+
* @param paramCount number of parameters in the annotated parameter list
116+
* @param nextParamIndex index of next unrolled parameter - to fetch default argument
117+
* @param nextSpan span of next forwarder - used to ensure the span is not identical by shifting (TODO remove)
118+
* @param annotatedParamListIndex index of the parameter list that contains unrolled parameters
119+
* @param isCaseApply if `defdef` is a case class apply/constructor - used for selection of default arguments
120+
*/
121+
private def generateSingleForwarder(defdef: DefDef,
118122
paramIndex: Int,
119123
paramCount: Int,
120124
nextParamIndex: Int,
121-
nextSymbol: Symbol,
125+
nextSpan: Span,
122126
annotatedParamListIndex: Int,
123-
isCaseApply: Boolean)(using Context) = {
127+
isCaseApply: Boolean)(using Context): DefDef = {
124128

125129
def initNewForwarder()(using Context): (TermSymbol, List[List[Symbol]]) = {
126130
val forwarderDefSymbol0 = Symbols.newSymbol(
@@ -129,15 +133,15 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
129133
defdef.symbol.flags &~ HasDefaultParams |
130134
Invisible | Synthetic,
131135
NoType, // fill in later
132-
coord = nextSymbol.span.shift(1) // shift by 1 to avoid "secondary constructor must call preceding" error
136+
coord = nextSpan.shift(1) // shift by 1 to avoid "secondary constructor must call preceding" error
133137
).entered
134138

135139
val newParamSymMappings = extractParamSymss(copyParamSym(_, forwarderDefSymbol0))
136140
val (oldParams, newParams) = newParamSymMappings.flatten.unzip
137141

138142
val newParamSymLists0 =
139-
newParamSymMappings.map: pairss =>
140-
pairss.map: (oldSym, newSym) =>
143+
newParamSymMappings.map: pairs =>
144+
pairs.map: (oldSym, newSym) =>
141145
newSym.info = oldSym.info.substSym(oldParams, newParams)
142146
newSym
143147

@@ -153,8 +157,6 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
153157
else ps.map(p => onSymbol(p.symbol))
154158
}
155159

156-
val paramCount = defdef.symbol.paramSymss(annotatedParamListIndex).size
157-
158160
val (forwarderDefSymbol, newParamSymLists) = initNewForwarder()
159161

160162
def forwarderRhs(): tpd.Tree = {
@@ -224,10 +226,10 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
224226
val forwarderDef =
225227
tpd.DefDef(forwarderDefSymbol, rhs = forwarderRhs())
226228

227-
forwarderDef.withSpan(nextSymbol.span.shift(1))
229+
forwarderDef.withSpan(nextSpan.shift(1))
228230
}
229231

230-
def generateFromProduct(startParamIndices: List[Int], paramCount: Int, defdef: DefDef)(using Context) = {
232+
private def generateFromProduct(startParamIndices: List[Int], paramCount: Int, defdef: DefDef)(using Context) = {
231233
cpy.DefDef(defdef)(
232234
name = defdef.name,
233235
paramss = defdef.paramss,
@@ -248,28 +250,35 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
248250
)
249251
)
250252
)
251-
} ++ Seq(
252-
CaseDef(
253-
Underscore(defn.IntType),
254-
EmptyTree,
255-
defdef.rhs
256-
)
253+
} :+ CaseDef(
254+
Underscore(defn.IntType),
255+
EmptyTree,
256+
defdef.rhs
257257
)
258258
)
259259
).setDefTree
260260
}
261261

262-
def generateSyntheticDefs(tree: Tree, compute: ComputeIndicies)(using Context): Option[(Symbol, Option[Symbol], Seq[DefDef])] = tree match {
262+
private enum Gen:
263+
case Substitute(origin: Symbol, newDef: DefDef)
264+
case Forwarders(origin: Symbol, forwarders: Seq[DefDef])
265+
266+
def origin: Symbol
267+
def extras: Seq[DefDef] = this match
268+
case Substitute(_, d) => d :: Nil
269+
case Forwarders(_, ds) => ds
270+
271+
private def generateSyntheticDefs(tree: Tree, compute: ComputeIndices)(using Context): Option[Gen] = tree match {
263272
case defdef: DefDef if defdef.paramss.nonEmpty =>
264273
import dotty.tools.dotc.core.NameOps.isConstructorName
265274

266275
val isCaseCopy =
267-
defdef.name.toString == "copy" && defdef.symbol.owner.is(CaseClass)
276+
defdef.name == nme.copy && defdef.symbol.owner.is(CaseClass)
268277

269278
val isCaseApply =
270-
defdef.name.toString == "apply" && defdef.symbol.owner.companionClass.is(CaseClass)
279+
defdef.name == nme.apply && defdef.symbol.owner.companionClass.is(CaseClass)
271280

272-
val isCaseFromProduct = defdef.name.toString == "fromProduct" && defdef.symbol.owner.companionClass.is(CaseClass)
281+
val isCaseFromProduct = defdef.name == nme.fromProduct && defdef.symbol.owner.companionClass.is(CaseClass)
273282

274283
val annotated =
275284
if (isCaseCopy) defdef.symbol.owner.primaryConstructor
@@ -282,25 +291,27 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
282291
case Seq((paramClauseIndex, annotationIndices)) =>
283292
val paramCount = annotated.paramSymss(paramClauseIndex).size
284293
if isCaseFromProduct then
285-
Some((defdef.symbol, Some(defdef.symbol), Seq(generateFromProduct(annotationIndices, paramCount, defdef))))
294+
Some(Gen.Substitute(
295+
origin = defdef.symbol,
296+
newDef = generateFromProduct(annotationIndices, paramCount, defdef)
297+
))
286298
else
287299
val (generatedDefs, _) =
288300
val indices = (annotationIndices :+ paramCount).sliding(2).toList.reverse
289-
indices.foldLeft((Seq.empty[DefDef], defdef.symbol)):
290-
case ((defdefs, nextSymbol), Seq(paramIndex, nextParamIndex)) =>
301+
indices.foldLeft((Seq.empty[DefDef], defdef.symbol.span)):
302+
case ((defdefs, nextSpan), Seq(paramIndex, nextParamIndex)) =>
291303
val forwarder = generateSingleForwarder(
292304
defdef,
293-
defdef.symbol.info,
294305
paramIndex,
295306
paramCount,
296307
nextParamIndex,
297-
nextSymbol,
308+
nextSpan,
298309
paramClauseIndex,
299310
isCaseApply
300311
)
301-
(forwarder +: defdefs, forwarder.symbol)
312+
(forwarder +: defdefs, forwarder.symbol.span)
302313
case _ => unreachable("sliding with at least 2 elements")
303-
Some((defdef.symbol, None, generatedDefs))
314+
Some(Gen.Forwarders(origin = defdef.symbol, forwarders = generatedDefs))
304315

305316
case multiple =>
306317
report.error("Cannot have multiple parameter lists containing `@unroll` annotation", defdef.srcPos)
@@ -310,12 +321,12 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
310321
case _ => None
311322
}
312323

313-
def unrollTemplate(tmpl: tpd.Template, compute: ComputeIndicies)(using Context): tpd.Tree = {
324+
private def unrollTemplate(tmpl: tpd.Template, compute: ComputeIndices)(using Context): tpd.Tree = {
314325

315326
val generatedBody = tmpl.body.flatMap(generateSyntheticDefs(_, compute))
316327
val generatedConstr0 = generateSyntheticDefs(tmpl.constr, compute)
317328
val allGenerated = generatedBody ++ generatedConstr0
318-
val bodySubs = generatedBody.flatMap((_, maybeSub, _) => maybeSub).toSet
329+
val bodySubs = generatedBody.collect({ case s: Gen.Substitute => s.origin }).toSet
319330
val otherDecls = tmpl.body.filterNot(d => d.symbol.exists && bodySubs(d.symbol))
320331

321332
/** inlined from compiler/src/dotty/tools/dotc/typer/Checking.scala */
@@ -326,28 +337,26 @@ class UnrollDefinitions extends MacroTransform, IdentityDenotTransformer {
326337
if allGenerated.nonEmpty then
327338
val byName = (tmpl.constr :: otherDecls).groupMap(_.symbol.name.toString)(_.symbol)
328339
for
329-
(src, _, dcls) <- allGenerated
330-
dcl <- dcls
340+
syntheticDefs <- allGenerated
341+
dcl <- syntheticDefs.extras
331342
do
332343
val replaced = dcl.symbol
333344
byName.get(dcl.name.toString).foreach { syms =>
334345
val clashes = syms.filter(checkClash(replaced, _))
335346
for existing <- clashes do
347+
val src = syntheticDefs.origin
336348
report.error(i"""Unrolled $replaced clashes with existing declaration.
337349
|Please remove the clashing definition, or the @unroll annotation.
338350
|Unrolled from ${hl(src.showDcl)} ${symLocation(src)}""".stripMargin, existing.srcPos)
339351
}
340352
end if
341353

342-
val generatedDefs = generatedBody.flatMap((_, _, gens) => gens)
343-
val generatedConstr = generatedConstr0.toList.flatMap((_, _, gens) => gens)
344-
345354
cpy.Template(tmpl)(
346355
tmpl.constr,
347356
tmpl.parents,
348357
tmpl.derived,
349358
tmpl.self,
350-
otherDecls ++ generatedDefs ++ generatedConstr
359+
otherDecls ++ allGenerated.flatMap(_.extras)
351360
)
352361
}
353362

0 commit comments

Comments
 (0)