Skip to content

Commit 44b8249

Browse files
authored
refactor XML deserialize (#1233)
1 parent a65dc90 commit 44b8249

File tree

3 files changed

+239
-99
lines changed

3 files changed

+239
-99
lines changed

codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/AwsQuery.kt

Lines changed: 64 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@ import software.amazon.smithy.kotlin.codegen.aws.protocols.core.AwsHttpBindingPr
1212
import software.amazon.smithy.kotlin.codegen.aws.protocols.core.QueryHttpBindingProtocolGenerator
1313
import software.amazon.smithy.kotlin.codegen.aws.protocols.formurl.QuerySerdeFormUrlDescriptorGenerator
1414
import software.amazon.smithy.kotlin.codegen.core.*
15+
import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
1516
import software.amazon.smithy.kotlin.codegen.model.*
16-
import software.amazon.smithy.kotlin.codegen.model.traits.OperationOutput
1717
import software.amazon.smithy.kotlin.codegen.rendering.protocol.*
1818
import software.amazon.smithy.kotlin.codegen.rendering.serde.*
19-
import software.amazon.smithy.kotlin.codegen.utils.dq
2019
import software.amazon.smithy.model.shapes.*
2120
import software.amazon.smithy.model.traits.*
2221

@@ -68,24 +67,6 @@ private class AwsQuerySerdeFormUrlDescriptorGenerator(
6867
member.hasTrait<XmlFlattenedTrait>()
6968
}
7069

71-
private class AwsQuerySerdeXmlDescriptorGenerator(
72-
ctx: RenderingContext<Shape>,
73-
memberShapes: List<MemberShape>? = null,
74-
) : XmlSerdeDescriptorGenerator(ctx, memberShapes) {
75-
76-
override fun getObjectDescriptorTraits(): List<SdkFieldDescriptorTrait> {
77-
val traits = super.getObjectDescriptorTraits().toMutableList()
78-
79-
if (objectShape.hasTrait<OperationOutput>()) {
80-
traits.removeIf { it.symbol == RuntimeTypes.Serde.SerdeXml.XmlSerialName }
81-
val serialName = objectShape.changeNameSuffix("Response" to "Result")
82-
traits.add(RuntimeTypes.Serde.SerdeXml.XmlSerialName, serialName.dq())
83-
}
84-
85-
return traits
86-
}
87-
}
88-
8970
private class AwsQuerySerializerGenerator(
9071
private val protocolGenerator: AwsQuery,
9172
) : AbstractQueryFormUrlSerializerGenerator(protocolGenerator, protocolGenerator.defaultTimestampFormat) {
@@ -98,50 +79,76 @@ private class AwsQuerySerializerGenerator(
9879
}
9980

10081
private class AwsQueryXmlParserGenerator(
101-
private val protocolGenerator: AwsQuery,
102-
) : XmlParserGenerator(protocolGenerator, protocolGenerator.defaultTimestampFormat) {
103-
104-
override fun descriptorGenerator(
105-
ctx: ProtocolGenerator.GenerationContext,
106-
shape: Shape,
107-
members: List<MemberShape>,
108-
writer: KotlinWriter,
109-
): XmlSerdeDescriptorGenerator = AwsQuerySerdeXmlDescriptorGenerator(ctx.toRenderingContext(protocolGenerator, shape, writer), members)
110-
111-
override fun renderDeserializeOperationBody(
112-
ctx: ProtocolGenerator.GenerationContext,
113-
op: OperationShape,
114-
documentMembers: List<MemberShape>,
115-
writer: KotlinWriter,
116-
) {
117-
writer.write("val deserializer = #T(payload)", RuntimeTypes.Serde.SerdeXml.XmlDeserializer)
118-
unwrapOperationResponseBody(op.id.name, writer)
119-
val shape = ctx.model.expectShape(op.output.get())
120-
renderDeserializerBody(ctx, shape, documentMembers, writer)
121-
}
82+
protocolGenerator: AwsQuery,
83+
) : XmlParserGenerator(protocolGenerator.defaultTimestampFormat) {
12284

12385
/**
12486
* Unwraps the response body as specified by
12587
* https://awslabs.github.io/smithy/1.0/spec/aws/aws-query-protocol.html#response-serialization so that the
12688
* deserializer is in the correct state.
89+
*
90+
* ```
91+
* <SomeOperationResponse>
92+
* <SomeOperationResult>
93+
* <-- SAME AS REST XML -->
94+
* </SomeOperationResult>
95+
*</SomeOperationResponse>
96+
* ```
12797
*/
128-
private fun unwrapOperationResponseBody(
129-
operationName: String,
98+
override fun unwrapOperationBody(
99+
ctx: ProtocolGenerator.GenerationContext,
100+
serdeCtx: SerdeCtx,
101+
op: OperationShape,
130102
writer: KotlinWriter,
131-
) {
132-
writer.write("// begin unwrap response wrapper")
133-
.write("val resultDescriptor = #T(#T.Struct, #T(#S))", RuntimeTypes.Serde.SdkFieldDescriptor, RuntimeTypes.Serde.SerialKind, RuntimeTypes.Serde.SerdeXml.XmlSerialName, "${operationName}Result")
134-
.withBlock("val wrapperDescriptor = #T.build {", "}", RuntimeTypes.Serde.SdkObjectDescriptor) {
135-
write("trait(#T(#S))", RuntimeTypes.Serde.SerdeXml.XmlSerialName, "${operationName}Response")
136-
write("#T(resultDescriptor)", RuntimeTypes.Serde.field)
103+
): SerdeCtx {
104+
val operationName = op.id.getName(ctx.service)
105+
106+
val unwrapAwsQueryOperation = buildSymbol {
107+
name = "unwrapAwsQueryResponse"
108+
namespace = ctx.settings.pkg.serde
109+
definitionFile = "AwsQueryUtil.kt"
110+
renderBy = { writer ->
111+
112+
writer.withBlock(
113+
"internal fun $name(root: #1T, operationName: #2T): #1T {",
114+
"}",
115+
RuntimeTypes.Serde.SerdeXml.XmlTagReader,
116+
KotlinTypes.String,
117+
) {
118+
write("val responseWrapperName = \"\${operationName}Response\"")
119+
write("val resultWrapperName = \"\${operationName}Result\"")
120+
withBlock(
121+
"if (root.tagName != responseWrapperName) {",
122+
"}",
123+
) {
124+
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "invalid root, expected \$responseWrapperName; found `\${root.tag}`")
125+
}
126+
127+
write("val resultTag = ${serdeCtx.tagReader}.nextTag()")
128+
withBlock(
129+
"if (resultTag == null || resultTag.tagName != resultWrapperName) {",
130+
"}",
131+
) {
132+
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "invalid result, expected \$resultWrapperName; found `\${resultTag?.tag}`")
133+
}
134+
135+
write("return resultTag")
136+
}
137137
}
138-
.write("")
139-
// abandon the iterator, this only occurs at the top level operational output
140-
.write("val wrapper = deserializer.#T(wrapperDescriptor)", RuntimeTypes.Serde.deserializeStruct)
141-
.withBlock("if (wrapper.findNextFieldIndex() != resultDescriptor.index) {", "}") {
142-
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "failed to unwrap $operationName response")
143-
}
144-
.write("// end unwrap response wrapper")
145-
.write("")
138+
}
139+
140+
writer.write("val unwrapped = #T(#L, #S)", unwrapAwsQueryOperation, serdeCtx.tagReader, operationName)
141+
142+
return SerdeCtx("unwrapped")
143+
}
144+
145+
override fun unwrapOperationError(
146+
ctx: ProtocolGenerator.GenerationContext,
147+
serdeCtx: SerdeCtx,
148+
errorShape: StructureShape,
149+
writer: KotlinWriter,
150+
): SerdeCtx {
151+
writer.write("val errReader = #T(${serdeCtx.tagReader})", RestXmlErrors.wrappedErrorResponseDeserializer(ctx))
152+
return SerdeCtx("errReader")
146153
}
147154
}

codegen/smithy-aws-kotlin-codegen/src/main/kotlin/software/amazon/smithy/kotlin/codegen/aws/protocols/Ec2Query.kt

Lines changed: 70 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,19 @@ package software.amazon.smithy.kotlin.codegen.aws.protocols
66

77
import software.amazon.smithy.aws.traits.protocols.Ec2QueryNameTrait
88
import software.amazon.smithy.aws.traits.protocols.Ec2QueryTrait
9+
import software.amazon.smithy.codegen.core.Symbol
910
import software.amazon.smithy.kotlin.codegen.aws.protocols.core.AbstractQueryFormUrlSerializerGenerator
1011
import software.amazon.smithy.kotlin.codegen.aws.protocols.core.QueryHttpBindingProtocolGenerator
1112
import software.amazon.smithy.kotlin.codegen.aws.protocols.formurl.QuerySerdeFormUrlDescriptorGenerator
1213
import software.amazon.smithy.kotlin.codegen.core.KotlinWriter
1314
import software.amazon.smithy.kotlin.codegen.core.RenderingContext
1415
import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes
15-
import software.amazon.smithy.kotlin.codegen.model.changeNameSuffix
16+
import software.amazon.smithy.kotlin.codegen.core.withBlock
17+
import software.amazon.smithy.kotlin.codegen.model.buildSymbol
1618
import software.amazon.smithy.kotlin.codegen.model.getTrait
17-
import software.amazon.smithy.kotlin.codegen.model.hasTrait
18-
import software.amazon.smithy.kotlin.codegen.model.traits.OperationOutput
1919
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
2020
import software.amazon.smithy.kotlin.codegen.rendering.protocol.toRenderingContext
2121
import software.amazon.smithy.kotlin.codegen.rendering.serde.*
22-
import software.amazon.smithy.kotlin.codegen.utils.dq
2322
import software.amazon.smithy.model.shapes.*
2423
import software.amazon.smithy.model.traits.XmlNameTrait
2524

@@ -73,24 +72,6 @@ private class Ec2QuerySerdeFormUrlDescriptorGenerator(
7372
targetShape.type == ShapeType.LIST
7473
}
7574

76-
private class Ec2QuerySerdeXmlDescriptorGenerator(
77-
ctx: RenderingContext<Shape>,
78-
memberShapes: List<MemberShape>? = null,
79-
) : XmlSerdeDescriptorGenerator(ctx, memberShapes) {
80-
81-
override fun getObjectDescriptorTraits(): List<SdkFieldDescriptorTrait> {
82-
val traits = super.getObjectDescriptorTraits().toMutableList()
83-
84-
if (objectShape.hasTrait<OperationOutput>()) {
85-
traits.removeIf { it.symbol == RuntimeTypes.Serde.SerdeXml.XmlSerialName }
86-
val serialName = objectShape.changeNameSuffix("Response" to "Result")
87-
traits.add(RuntimeTypes.Serde.SerdeXml.XmlSerialName, serialName.dq())
88-
}
89-
90-
return traits
91-
}
92-
}
93-
9475
private class Ec2QuerySerializerGenerator(
9576
private val protocolGenerator: Ec2Query,
9677
) : AbstractQueryFormUrlSerializerGenerator(protocolGenerator, protocolGenerator.defaultTimestampFormat) {
@@ -104,13 +85,73 @@ private class Ec2QuerySerializerGenerator(
10485
}
10586

10687
private class Ec2QueryParserGenerator(
107-
private val protocolGenerator: Ec2Query,
108-
) : XmlParserGenerator(protocolGenerator, protocolGenerator.defaultTimestampFormat) {
109-
110-
override fun descriptorGenerator(
88+
protocolGenerator: Ec2Query,
89+
) : XmlParserGenerator(protocolGenerator.defaultTimestampFormat) {
90+
override fun unwrapOperationError(
11191
ctx: ProtocolGenerator.GenerationContext,
112-
shape: Shape,
113-
members: List<MemberShape>,
92+
serdeCtx: SerdeCtx,
93+
errorShape: StructureShape,
11494
writer: KotlinWriter,
115-
): XmlSerdeDescriptorGenerator = Ec2QuerySerdeXmlDescriptorGenerator(ctx.toRenderingContext(protocolGenerator, shape, writer), members)
95+
): SerdeCtx {
96+
val unwrapFn = unwrapErrorResponse(ctx)
97+
writer.write("val errReader = #T(${serdeCtx.tagReader})", unwrapFn)
98+
return SerdeCtx("errReader")
99+
}
100+
101+
/**
102+
* Error deserializer for a wrapped error response
103+
*
104+
* ```
105+
* <Response>
106+
* <Errors>
107+
* <Error>
108+
* <-- DATA -->>
109+
* </Error>
110+
* </Errors>
111+
* </Response>
112+
* ```
113+
*
114+
* See https://smithy.io/2.0/aws/protocols/aws-ec2-query-protocol.html#operation-error-serialization
115+
*/
116+
private fun unwrapErrorResponse(ctx: ProtocolGenerator.GenerationContext): Symbol = buildSymbol {
117+
name = "unwrapXmlErrorResponse"
118+
namespace = ctx.settings.pkg.serde
119+
definitionFile = "XmlErrorUtils.kt"
120+
renderBy = { writer ->
121+
writer.dokka("Handle [wrapped](https://smithy.io/2.0/aws/protocols/aws-ec2-query-protocol.html#operation-error-serialization) error responses")
122+
writer.withBlock(
123+
"internal fun $name(root: #1T): #1T {",
124+
"}",
125+
RuntimeTypes.Serde.SerdeXml.XmlTagReader,
126+
) {
127+
withBlock(
128+
"if (root.tagName != #S) {",
129+
"}",
130+
"Response",
131+
) {
132+
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "invalid root, expected <Response>; found `\${root.tag}`")
133+
}
134+
135+
write("val errorsTag = root.nextTag()")
136+
withBlock(
137+
"if (errorsTag == null || errorsTag.tagName != #S) {",
138+
"}",
139+
"Errors",
140+
) {
141+
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "invalid error, expected <Errors>; found `\${errorsTag?.tag}`")
142+
}
143+
144+
write("val errTag = errorsTag.nextTag()")
145+
withBlock(
146+
"if (errTag == null || errTag.tagName != #S) {",
147+
"}",
148+
"Error",
149+
) {
150+
write("throw #T(#S)", RuntimeTypes.Serde.DeserializationException, "invalid error, expected <Error>; found `\${errTag?.tag}`")
151+
}
152+
153+
write("return errTag")
154+
}
155+
}
156+
}
116157
}

0 commit comments

Comments
 (0)