Skip to content

Commit 80d9dcc

Browse files
authored
Merge pull request #9549 from dotty-staging/topic/enum-value-custom-tostring
fix #7227: allow custom toString on enum
2 parents 6e6f67b + 80a4d19 commit 80d9dcc

File tree

16 files changed

+199
-31
lines changed

16 files changed

+199
-31
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,9 @@ object desugar {
586586
yield syntheticProperty(selName, caseParams(i).tpt,
587587
Select(This(EmptyTypeIdent), caseParams(i).name))
588588

589-
def ordinalMeths = if (isEnumCase) ordinalMethLit(nextOrdinal(CaseKind.Class)._1) :: Nil else Nil
589+
def enumMeths =
590+
if (isEnumCase) ordinalMethLit(nextOrdinal(CaseKind.Class)._1) :: enumLabelLit(className.toString) :: Nil
591+
else Nil
590592
def copyMeths = {
591593
val hasRepeatedParam = constrVparamss.exists(_.exists {
592594
case ValDef(_, tpt, _) => isRepeated(tpt)
@@ -605,7 +607,7 @@ object desugar {
605607
}
606608

607609
if (isCaseClass)
608-
copyMeths ::: ordinalMeths ::: productElemMeths
610+
copyMeths ::: enumMeths ::: productElemMeths
609611
else Nil
610612
}
611613

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -125,18 +125,21 @@ object DesugarEnums {
125125
/** A creation method for a value of enum type `E`, which is defined as follows:
126126
*
127127
* private def $new(_$ordinal: Int, $name: String) = new E with scala.runtime.EnumValue {
128-
* def ordinal = _$ordinal // if `E` does not derive from jl.Enum
129-
* override def toString = $name // if `E` does not derive from jl.Enum
128+
* def ordinal = _$ordinal // if `E` does not derive from `java.lang.Enum`
129+
* def enumLabel = $name // if `E` does not derive from `java.lang.Enum`
130+
* def enumLabel = this.name // if `E` derives from `java.lang.Enum`
130131
* $values.register(this)
131132
* }
132133
*/
133134
private def enumValueCreator(using Context) = {
134135
val fieldMethods =
135-
if isJavaEnum then Nil
136-
else
137-
val ordinalDef = ordinalMeth(Ident(nme.ordinalDollar_))
138-
val toStringDef = toStringMeth(Ident(nme.nameDollar))
139-
List(ordinalDef, toStringDef)
136+
if isJavaEnum then
137+
val enumLabelDef = enumLabelMeth(Select(This(Ident(tpnme.EMPTY)), nme.name))
138+
enumLabelDef :: Nil
139+
else
140+
val ordinalDef = ordinalMeth(Ident(nme.ordinalDollar_))
141+
val enumLabelDef = enumLabelMeth(Ident(nme.nameDollar))
142+
ordinalDef :: enumLabelDef :: Nil
140143
val creator = New(Template(
141144
constr = emptyConstructor,
142145
parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue) :: Nil,
@@ -273,14 +276,14 @@ object DesugarEnums {
273276
def ordinalMeth(body: Tree)(using Context): DefDef =
274277
DefDef(nme.ordinal, Nil, Nil, TypeTree(defn.IntType), body)
275278

276-
def toStringMeth(body: Tree)(using Context): DefDef =
277-
DefDef(nme.toString_, Nil, Nil, TypeTree(defn.StringType), body).withFlags(Override)
279+
def enumLabelMeth(body: Tree)(using Context): DefDef =
280+
DefDef(nme.enumLabel, Nil, Nil, TypeTree(defn.StringType), body)
278281

279282
def ordinalMethLit(ord: Int)(using Context): DefDef =
280283
ordinalMeth(Literal(Constant(ord)))
281284

282-
def toStringMethLit(name: String)(using Context): DefDef =
283-
toStringMeth(Literal(Constant(name)))
285+
def enumLabelLit(name: String)(using Context): DefDef =
286+
enumLabelMeth(Literal(Constant(name)))
284287

285288
/** Expand a module definition representing a parameterless enum case */
286289
def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, span: Span)(using Context): Tree = {
@@ -290,16 +293,12 @@ object DesugarEnums {
290293
expandSimpleEnumCase(name, mods, span)
291294
else {
292295
val (tag, scaffolding) = nextOrdinal(CaseKind.Object)
293-
val fieldMethods =
294-
if isJavaEnum then Nil
295-
else
296-
val ordinalDef = ordinalMethLit(tag)
297-
val toStringDef = toStringMethLit(name.toString)
298-
List(ordinalDef, toStringDef)
296+
val ordinalDef = if isJavaEnum then Nil else ordinalMethLit(tag) :: Nil
297+
val enumLabelDef = enumLabelLit(name.toString)
299298
val impl1 = cpy.Template(impl)(
300299
parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue),
301-
body = fieldMethods ::: registerCall :: Nil)
302-
.withAttachment(ExtendsSingletonMirror, ())
300+
body = ordinalDef ::: enumLabelDef :: registerCall :: Nil
301+
).withAttachment(ExtendsSingletonMirror, ())
303302
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span))
304303
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)
305304
}

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,6 @@ class Definitions {
648648
@tu lazy val NoneModule: Symbol = requiredModule("scala.None")
649649

650650
@tu lazy val EnumClass: ClassSymbol = requiredClass("scala.Enum")
651-
@tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal)
652651

653652
@tu lazy val EnumValuesClass: ClassSymbol = requiredClass("scala.runtime.EnumValues")
654653

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,7 @@ object StdNames {
453453
val emptyValDef: N = "emptyValDef"
454454
val end: N = "end"
455455
val ensureAccessible : N = "ensureAccessible"
456+
val enumLabel: N = "enumLabel"
456457
val eq: N = "eq"
457458
val eqInstance: N = "eqInstance"
458459
val equalsNumChar : N = "equalsNumChar"

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
5757
private var myValueSymbols: List[Symbol] = Nil
5858
private var myCaseSymbols: List[Symbol] = Nil
5959
private var myCaseModuleSymbols: List[Symbol] = Nil
60+
private var myEnumValueSymbols: List[Symbol] = Nil
61+
private var myNonJavaEnumValueSymbols: List[Symbol] = Nil
6062

6163
private def initSymbols(using Context) =
6264
if (myValueSymbols.isEmpty) {
@@ -65,11 +67,15 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
6567
defn.Product_productArity, defn.Product_productPrefix, defn.Product_productElement,
6668
defn.Product_productElementName)
6769
myCaseModuleSymbols = myCaseSymbols.filter(_ ne defn.Any_equals)
70+
myEnumValueSymbols = List(defn.Product_productPrefix)
71+
myNonJavaEnumValueSymbols = myEnumValueSymbols :+ defn.Any_toString
6872
}
6973

7074
def valueSymbols(using Context): List[Symbol] = { initSymbols; myValueSymbols }
7175
def caseSymbols(using Context): List[Symbol] = { initSymbols; myCaseSymbols }
7276
def caseModuleSymbols(using Context): List[Symbol] = { initSymbols; myCaseModuleSymbols }
77+
def enumValueSymbols(using Context): List[Symbol] = { initSymbols; myEnumValueSymbols }
78+
def nonJavaEnumValueSymbols(using Context): List[Symbol] = { initSymbols; myNonJavaEnumValueSymbols }
7379

7480
private def existingDef(sym: Symbol, clazz: ClassSymbol)(using Context): Symbol = {
7581
val existing = sym.matchingMember(clazz.thisType)
@@ -89,11 +95,15 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
8995
if (isDerivedValueClass(clazz)) clazz.paramAccessors.take(1) // Tail parameters can only be `erased`
9096
else clazz.caseAccessors
9197
val isEnumCase = clazz.derivesFrom(defn.EnumClass) && clazz != defn.EnumClass
98+
val isEnumValue = isEnumCase && clazz.isAnonymousClass && clazz.classParents.head.classSymbol.is(Enum)
99+
val isNonJavaEnumValue = isEnumValue && !clazz.derivesFrom(defn.JavaEnumClass)
92100

93101
val symbolsToSynthesize: List[Symbol] =
94102
if (clazz.is(Case))
95103
if (clazz.is(Module)) caseModuleSymbols
96104
else caseSymbols
105+
else if (isNonJavaEnumValue) nonJavaEnumValueSymbols
106+
else if (isEnumValue) enumValueSymbols
97107
else if (isDerivedValueClass(clazz)) valueSymbols
98108
else Nil
99109

@@ -113,13 +123,22 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
113123
def ownName: Tree =
114124
Literal(Constant(clazz.name.stripModuleClassSuffix.toString))
115125

126+
def callEnumLabel: Tree =
127+
Select(This(clazz), nme.enumLabel).ensureApplied
128+
129+
def toStringBody(vrefss: List[List[Tree]]): Tree =
130+
if (clazz.is(ModuleClass)) ownName
131+
else if (isNonJavaEnumValue) callEnumLabel
132+
else forwardToRuntime(vrefss.head)
133+
116134
def syntheticRHS(vrefss: List[List[Tree]])(using Context): Tree = synthetic.name match {
117135
case nme.hashCode_ if isDerivedValueClass(clazz) => valueHashCodeBody
118136
case nme.hashCode_ => chooseHashcode
119-
case nme.toString_ => if (clazz.is(ModuleClass)) ownName else forwardToRuntime(vrefss.head)
137+
case nme.toString_ => toStringBody(vrefss)
120138
case nme.equals_ => equalsBody(vrefss.head.head)
121139
case nme.canEqual_ => canEqualBody(vrefss.head.head)
122140
case nme.productArity => Literal(Constant(accessors.length))
141+
case nme.productPrefix if isEnumValue => callEnumLabel
123142
case nme.productPrefix => ownName
124143
case nme.productElement => productElementBody(accessors.length, vrefss.head.head)
125144
case nme.productElementName => productElementNameBody(accessors.length, vrefss.head.head)

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,6 +1130,7 @@ trait Checking {
11301130

11311131
end checkEnumParent
11321132

1133+
11331134
/** Check that all references coming from enum cases in an enum companion object
11341135
* are legal.
11351136
* @param cdef the enum companion object class

docs/docs/reference/enums/desugarEnums.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,15 +174,19 @@ If `E` contains at least one simple case, its companion object will define in ad
174174
follows.
175175
```scala
176176
private def $new(_$ordinal: Int, $name: String) = new E with runtime.EnumValue {
177-
def ordinal = _$ordinal // if `E` does not have `java.lang.Enum` as a parent
178-
override def toString = $name // if `E` does not have `java.lang.Enum` as a parent
177+
def ordinal = _$ordinal
178+
def enumLabel = $name
179+
override def productPrefix = enumLabel // if not overridden in `E`
180+
override def toString = enumLabel // if not overridden in `E`
179181
$values.register(this) // register enum value so that `valueOf` and `values` can return it.
180182
}
181183
```
182184

183185
The anonymous class also implements the abstract `Product` methods that it inherits from `Enum`.
184-
The `ordinal` method is only generated if the enum does not extend from `java.lang.Enum` (as Scala enums do not extend `java.lang.Enum`s unless explicitly specified). In case it does, there is no need to generate `ordinal` as `java.lang.Enum` defines it. Similarly there is no need to override `toString` as that is defined in terms of `name` in
185-
`java.lang.Enum`.
186+
The `ordinal` method is only generated if the enum does not extend from `java.lang.Enum` (as Scala enums do not extend
187+
`java.lang.Enum`s unless explicitly specified). In case it does, there is no need to generate `ordinal` as
188+
`java.lang.Enum` defines it. Similarly there is no need to override `toString` as that is defined in terms of `name` in
189+
`java.lang.Enum`. Finally, `enumLabel` will call `this.name` when `E` extends `java.lang.Enum`.
186190

187191
### Scopes for Enum Cases
188192

docs/docs/reference/enums/enums.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,17 @@ For a more in-depth example of using Scala 3 enums from Java, see [this test](ht
110110
### Implementation
111111

112112
Enums are represented as `sealed` classes that extend the `scala.Enum` trait.
113-
This trait defines a single public method, `ordinal`:
113+
This trait defines two public methods, `ordinal` and `enumLabel`:
114114

115115
```scala
116116
package scala
117117

118118
/** A base trait of all enum classes */
119119
trait Enum extends Product with Serializable {
120120

121+
/** A string uniquely identifying a case of an enum */
122+
def enumLabel: String
123+
121124
/** A number uniquely identifying a case of an enum */
122125
def ordinal: Int
123126
}
@@ -130,7 +133,9 @@ For instance, the `Venus` value above would be defined like this:
130133
val Venus: Planet =
131134
new Planet(4.869E24, 6051800.0) {
132135
def ordinal: Int = 1
133-
override def toString: String = "Venus"
136+
def enumLabel: String = "Venus"
137+
override def productPrefix: String = enumLabel
138+
override def toString: String = enumLabel
134139
// internal code to register value
135140
}
136141
```

library/src-bootstrapped/scala/Enum.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,8 @@ package scala
33
/** A base trait of all enum classes */
44
trait Enum extends Product, Serializable:
55

6+
/** A string uniquely identifying a case of an enum */
7+
def enumLabel: String
8+
69
/** A number uniquely identifying a case of an enum */
710
def ordinal: Int
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package scala.runtime
2+
3+
import scala.collection.immutable.TreeMap
4+
5+
class EnumValues[E <: Enum] {
6+
private[this] var myMap: Map[Int, E] = TreeMap.empty
7+
private[this] var fromNameCache: Map[String, E] = null
8+
9+
def register(v: E) = {
10+
require(!myMap.contains(v.ordinal))
11+
myMap = myMap.updated(v.ordinal, v)
12+
fromNameCache = null
13+
}
14+
15+
def fromInt: Map[Int, E] = myMap
16+
def fromName: Map[String, E] = {
17+
if (fromNameCache == null) fromNameCache = myMap.values.map(v => v.enumLabel -> v).toMap
18+
fromNameCache
19+
}
20+
def values: Iterable[E] = myMap.values
21+
}

0 commit comments

Comments
 (0)