Skip to content

Commit 2cb2547

Browse files
authored
fix: correctly codegen defaults for enum shapes (#944)
1 parent 7d34ee1 commit 2cb2547

File tree

3 files changed

+35
-15
lines changed

3 files changed

+35
-15
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"id": "855b179f-a977-459f-b7d7-fc0bccd208d7",
3+
"type": "bugfix",
4+
"description": "Correctly codegen defaults for enum shapes"
5+
}

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/core/KotlinSymbolProvider.kt

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -212,13 +212,28 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli
212212
}
213213
}
214214

215-
private fun DefaultTrait.getDefaultValue(targetShape: Shape): String? = when {
216-
toNode().toString() == "null" || targetShape is BlobShape && toNode().toString() == "" -> null
217-
toNode().isNumberNode -> getDefaultValueForNumber(targetShape, toNode().toString())
218-
toNode().isArrayNode -> "listOf()"
219-
toNode().isObjectNode -> "mapOf()"
220-
toNode().isStringNode -> toNode().toString().dq()
221-
else -> toNode().toString()
215+
private fun DefaultTrait.getDefaultValue(targetShape: Shape): String? {
216+
val node = toNode()
217+
return when {
218+
node.toString() == "null" || targetShape is BlobShape && node.toString() == "" -> null
219+
220+
// Check if target is an enum before treating the default like a regular number/string
221+
targetShape.isEnum -> {
222+
val enumSymbol = toSymbol(targetShape)
223+
val arg = when {
224+
targetShape.isStringShape -> node.toString().dq()
225+
targetShape.isIntEnumShape -> getDefaultValueForNumber(ShapeType.INTEGER, node.toString())
226+
else -> throw CodegenException("Unknown enum type for $targetShape")
227+
}
228+
"${enumSymbol.fullName}.fromValue($arg)"
229+
}
230+
231+
node.isNumberNode -> getDefaultValueForNumber(targetShape.type, node.toString())
232+
node.isArrayNode -> "listOf()"
233+
node.isObjectNode -> "mapOf()"
234+
node.isStringNode -> node.toString().dq()
235+
else -> node.toString()
236+
}
222237
}
223238

224239
override fun timestampShape(shape: TimestampShape?): Symbol {
@@ -277,12 +292,12 @@ class KotlinSymbolProvider(private val model: Model, private val settings: Kotli
277292
return builder
278293
}
279294

280-
private fun getDefaultValueForNumber(shape: Shape, value: String) = when (shape) {
281-
is LongShape -> "${value}L"
282-
is FloatShape -> "${value}f"
283-
is DoubleShape -> if (value.matches("[0-9]*\\.[0-9]+".toRegex())) value else "$value.0"
284-
is ShortShape -> "$value.toShort()"
285-
is ByteShape -> "$value.toByte()"
295+
private fun getDefaultValueForNumber(type: ShapeType, value: String) = when (type) {
296+
ShapeType.LONG -> "${value}L"
297+
ShapeType.FLOAT -> "${value}f"
298+
ShapeType.DOUBLE -> if (value.matches("[0-9]*\\.[0-9]+".toRegex())) value else "$value.0"
299+
ShapeType.SHORT -> "$value.toShort()"
300+
ShapeType.BYTE -> "$value.toByte()"
286301
else -> value
287302
}
288303

codegen/smithy-kotlin-codegen/src/test/kotlin/software/amazon/smithy/kotlin/codegen/core/SymbolProviderTest.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ class SymbolProviderTest {
246246
val provider: SymbolProvider = KotlinCodegenPlugin.createSymbolProvider(model)
247247
val member = model.expectShape<MemberShape>("com.test#MyStruct\$foo")
248248
val memberSymbol = provider.toSymbol(member)
249-
assertEquals("\"club\"", memberSymbol.defaultValue())
249+
assertEquals("""com.test.model.Suit.fromValue("club")""", memberSymbol.defaultValue())
250250
}
251251

252252
@Test
@@ -268,7 +268,7 @@ class SymbolProviderTest {
268268
val provider: SymbolProvider = KotlinCodegenPlugin.createSymbolProvider(model)
269269
val member = model.expectShape<MemberShape>("com.test#MyStruct\$foo")
270270
val memberSymbol = provider.toSymbol(member)
271-
assertEquals("2", memberSymbol.defaultValue())
271+
assertEquals("com.test.model.Season.fromValue(2)", memberSymbol.defaultValue())
272272
}
273273

274274
@ParameterizedTest(name = "{index} ==> ''can default document with {0} type''")

0 commit comments

Comments
 (0)