Skip to content

Commit 4f06399

Browse files
authored
feat(codegen): improve nullability of generated types (#968)
1 parent 94b9904 commit 4f06399

File tree

40 files changed

+1322
-489
lines changed

40 files changed

+1322
-489
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"id": "0166e037-d43a-48cd-9deb-8dacd55f9aeb",
3+
"type": "feature",
4+
"description": "Refactor codegen to support treating `@required` members as non-nullable."
5+
}

build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ apiValidation {
137137
"http-benchmarks",
138138
"serde-benchmarks",
139139
"serde-benchmarks-codegen",
140+
"nullability-tests",
140141
"paginator-tests",
141142
"waiter-tests",
142143
"compile",

codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/CodegenTestUtils.kt

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import software.amazon.smithy.build.MockManifest
1414
import software.amazon.smithy.codegen.core.Symbol
1515
import software.amazon.smithy.codegen.core.SymbolProvider
1616
import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin
17+
import software.amazon.smithy.kotlin.codegen.KotlinSettings
1718
import software.amazon.smithy.kotlin.codegen.core.*
1819
import software.amazon.smithy.kotlin.codegen.model.buildSymbol
1920
import software.amazon.smithy.kotlin.codegen.model.expectShape
@@ -51,8 +52,14 @@ fun testRender(
5152
}
5253

5354
/** Drive codegen for serialization of a given shape */
54-
fun codegenSerializerForShape(model: Model, shapeId: String, location: HttpBinding.Location = HttpBinding.Location.DOCUMENT): String {
55-
val ctx = model.newTestContext()
55+
fun codegenSerializerForShape(
56+
model: Model,
57+
shapeId: String,
58+
location: HttpBinding.Location = HttpBinding.Location.DOCUMENT,
59+
settings: KotlinSettings? = null,
60+
): String {
61+
val resolvedSettings = settings ?: model.defaultSettings(TestModelDefault.SERVICE_NAME, TestModelDefault.NAMESPACE)
62+
val ctx = model.newTestContext(settings = resolvedSettings)
5663

5764
val op = ctx.generationCtx.model.expectShape(ShapeId.from(shapeId))
5865
return testRender(ctx.requestMembers(op, location)) { members, writer ->
@@ -250,9 +257,15 @@ fun generateCode(generator: (KotlinWriter) -> Unit): String {
250257
return rawCodegen.substring(rawCodegen.indexOf(packageDeclaration) + packageDeclaration.length).trim()
251258
}
252259

253-
fun KotlinCodegenPlugin.Companion.createSymbolProvider(model: Model, rootNamespace: String = TestModelDefault.NAMESPACE, sdkId: String = TestModelDefault.SDK_ID, serviceName: String = TestModelDefault.SERVICE_NAME): SymbolProvider {
254-
val settings = model.defaultSettings(serviceName = serviceName, packageName = rootNamespace, sdkId = sdkId)
255-
return createSymbolProvider(model, settings)
260+
fun KotlinCodegenPlugin.Companion.createSymbolProvider(
261+
model: Model,
262+
rootNamespace: String = TestModelDefault.NAMESPACE,
263+
sdkId: String = TestModelDefault.SDK_ID,
264+
serviceName: String = TestModelDefault.SERVICE_NAME,
265+
settings: KotlinSettings? = null,
266+
): SymbolProvider {
267+
val resolvedSettings = settings ?: model.defaultSettings(serviceName = serviceName, packageName = rootNamespace, sdkId = sdkId)
268+
return createSymbolProvider(model, resolvedSettings)
256269
}
257270

258271
/**

codegen/smithy-kotlin-codegen-testutils/src/main/kotlin/software/amazon/smithy/kotlin/codegen/test/ModelTestUtils.kt

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,16 @@ package software.amazon.smithy.kotlin.codegen.test
66

77
import software.amazon.smithy.build.MockManifest
88
import software.amazon.smithy.codegen.core.SymbolProvider
9-
import software.amazon.smithy.kotlin.codegen.KotlinCodegenPlugin
10-
import software.amazon.smithy.kotlin.codegen.KotlinSettings
9+
import software.amazon.smithy.kotlin.codegen.*
1110
import software.amazon.smithy.kotlin.codegen.core.CodegenContext
1211
import software.amazon.smithy.kotlin.codegen.core.KotlinDelegator
13-
import software.amazon.smithy.kotlin.codegen.inferService
1412
import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration
1513
import software.amazon.smithy.kotlin.codegen.model.OperationNormalizer
1614
import software.amazon.smithy.kotlin.codegen.model.shapes
1715
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
1816
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
1917
import software.amazon.smithy.model.Model
18+
import software.amazon.smithy.model.knowledge.NullableIndex.CheckMode
2019
import software.amazon.smithy.model.node.Node
2120
import software.amazon.smithy.model.shapes.ServiceShape
2221
import software.amazon.smithy.model.shapes.ShapeId
@@ -116,12 +115,12 @@ fun Model.toSmithyIDL(): String {
116115
fun Model.newTestContext(
117116
serviceName: String = TestModelDefault.SERVICE_NAME,
118117
packageName: String = TestModelDefault.NAMESPACE,
119-
settings: KotlinSettings = this.defaultSettings(serviceName, packageName),
118+
settings: KotlinSettings = defaultSettings(serviceName, packageName),
120119
generator: ProtocolGenerator = MockHttpProtocolGenerator(this),
121120
integrations: List<KotlinIntegration> = listOf(),
122121
): TestContext {
123122
val manifest = MockManifest()
124-
val provider: SymbolProvider = KotlinCodegenPlugin.createSymbolProvider(model = this, rootNamespace = packageName, serviceName = serviceName)
123+
val provider: SymbolProvider = KotlinCodegenPlugin.createSymbolProvider(model = this, rootNamespace = packageName, serviceName = serviceName, settings = settings)
125124
val service = this.getShape(ShapeId.from("$packageName#$serviceName")).get().asServiceShape().get()
126125
val delegator = KotlinDelegator(settings, this, manifest, provider)
127126

@@ -173,6 +172,8 @@ fun Model.defaultSettings(
173172
packageVersion: String = TestModelDefault.MODEL_VERSION,
174173
sdkId: String = TestModelDefault.SDK_ID,
175174
generateDefaultBuildFiles: Boolean = false,
175+
nullabilityCheckMode: CheckMode = CheckMode.CLIENT_CAREFUL,
176+
defaultValueSerializationMode: DefaultValueSerializationMode = DefaultValueSerializationMode.WHEN_DIFFERENT,
176177
): KotlinSettings {
177178
val serviceId = if (serviceName == null) {
178179
this.inferService()
@@ -197,6 +198,12 @@ fun Model.defaultSettings(
197198
Node.objectNode()
198199
.withMember("generateDefaultBuildFiles", Node.from(generateDefaultBuildFiles)),
199200
)
201+
.withMember(
202+
"api",
203+
Node.objectNode()
204+
.withMember(ApiSettings.NULLABILITY_CHECK_MODE, Node.from(nullabilityCheckMode.kotlinPluginSetting))
205+
.withMember(ApiSettings.DEFAULT_VALUE_SERIALIZATION_MODE, Node.from(defaultValueSerializationMode.value)),
206+
)
200207
.build(),
201208
)
202209
}

codegen/smithy-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/KotlinSettings.kt

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,19 @@ package software.amazon.smithy.kotlin.codegen
77

88
import software.amazon.smithy.codegen.core.CodegenException
99
import software.amazon.smithy.kotlin.codegen.lang.isValidPackageName
10+
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
11+
import software.amazon.smithy.kotlin.codegen.utils.toCamelCase
1012
import software.amazon.smithy.model.Model
13+
import software.amazon.smithy.model.knowledge.NullableIndex.CheckMode
1114
import software.amazon.smithy.model.knowledge.ServiceIndex
1215
import software.amazon.smithy.model.node.ObjectNode
1316
import software.amazon.smithy.model.node.StringNode
1417
import software.amazon.smithy.model.shapes.ServiceShape
1518
import software.amazon.smithy.model.shapes.Shape
1619
import software.amazon.smithy.model.shapes.ShapeId
17-
import java.lang.IllegalArgumentException
1820
import java.util.Optional
1921
import java.util.logging.Logger
22+
import kotlin.IllegalArgumentException
2023
import kotlin.streams.toList
2124

2225
// shapeId of service from which to generate an SDK
@@ -81,7 +84,7 @@ data class KotlinSettings(
8184
* @return Returns the extracted settings
8285
*/
8386
fun from(model: Model, config: ObjectNode): KotlinSettings {
84-
config.warnIfAdditionalProperties(listOf(SERVICE, PACKAGE_SETTINGS, BUILD_SETTINGS, SDK_ID))
87+
config.warnIfAdditionalProperties(listOf(SERVICE, PACKAGE_SETTINGS, BUILD_SETTINGS, SDK_ID, API_SETTINGS))
8588

8689
val serviceId = config.getStringMember(SERVICE)
8790
.map(StringNode::expectShapeId)
@@ -224,19 +227,70 @@ enum class Visibility(val value: String) {
224227
}
225228
}
226229

230+
private fun checkModefromValue(value: String): CheckMode {
231+
val camelCaseToMode = CheckMode.values().associateBy { it.toString().toCamelCase() }
232+
return requireNotNull(camelCaseToMode[value]) { "$value is not a valid CheckMode, expected one of ${camelCaseToMode.keys}" }
233+
}
234+
235+
/**
236+
* Get the plugin setting for this check mode
237+
*/
238+
val CheckMode.kotlinPluginSetting: String
239+
get() = toString().toCamelCase()
240+
241+
enum class DefaultValueSerializationMode(val value: String) {
242+
/**
243+
* Always serialize values even if they are set to the modeled default
244+
*/
245+
ALWAYS("always"),
246+
247+
/**
248+
* Only serialize values when they differ from the modeled default or are marked `@required`
249+
*/
250+
WHEN_DIFFERENT("whenDifferent"),
251+
;
252+
override fun toString(): String = value
253+
companion object {
254+
fun fromValue(value: String): DefaultValueSerializationMode =
255+
values().find {
256+
it.value == value
257+
} ?: throw IllegalArgumentException("$value is not a valid DefaultValueSerializationMode, expected one of ${values().map { it.value }}")
258+
}
259+
}
260+
227261
/**
228262
* Contains API settings for a Kotlin project
229263
* @param visibility Enum representing the visibility of code-generated classes, objects, interfaces, etc.
264+
* @param nullabilityCheckMode Enum representing the nullability check mode to use
265+
* @param defaultValueSerializationMode Enum representing when default values should be serialized
230266
*/
231267
data class ApiSettings(
232268
val visibility: Visibility = Visibility.PUBLIC,
269+
val nullabilityCheckMode: CheckMode = CheckMode.CLIENT_CAREFUL,
270+
val defaultValueSerializationMode: DefaultValueSerializationMode = DefaultValueSerializationMode.WHEN_DIFFERENT,
233271
) {
234272
companion object {
235273
const val VISIBILITY = "visibility"
274+
const val NULLABILITY_CHECK_MODE = "nullabilityCheckMode"
275+
const val DEFAULT_VALUE_SERIALIZATION_MODE = "defaultValueSerializationMode"
236276

237277
fun fromNode(node: Optional<ObjectNode>): ApiSettings = node.map {
238-
val visibility = Visibility.fromValue(node.get().getStringMemberOrDefault(VISIBILITY, "public"))
239-
ApiSettings(visibility)
278+
val visibility = node.get()
279+
.getStringMember(VISIBILITY)
280+
.map { Visibility.fromValue(it.value) }
281+
.getOrNull() ?: Visibility.PUBLIC
282+
val checkMode = node.get()
283+
.getStringMember(NULLABILITY_CHECK_MODE)
284+
.map { checkModefromValue(it.value) }
285+
.getOrNull() ?: CheckMode.CLIENT_CAREFUL
286+
val defaultValueSerializationMode = DefaultValueSerializationMode.fromValue(
287+
node.get()
288+
.getStringMemberOrDefault(
289+
DEFAULT_VALUE_SERIALIZATION_MODE,
290+
DefaultValueSerializationMode.WHEN_DIFFERENT.value,
291+
),
292+
)
293+
ApiSettings(visibility, checkMode, defaultValueSerializationMode)
240294
}.orElse(Default)
241295

242296
/**

0 commit comments

Comments
 (0)