@@ -12,11 +12,10 @@ import software.amazon.smithy.kotlin.codegen.aws.protocols.core.AwsHttpBindingPr
1212import software.amazon.smithy.kotlin.codegen.aws.protocols.core.QueryHttpBindingProtocolGenerator
1313import software.amazon.smithy.kotlin.codegen.aws.protocols.formurl.QuerySerdeFormUrlDescriptorGenerator
1414import software.amazon.smithy.kotlin.codegen.core.*
15+ import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
1516import software.amazon.smithy.kotlin.codegen.model.*
16- import software.amazon.smithy.kotlin.codegen.model.traits.OperationOutput
1717import software.amazon.smithy.kotlin.codegen.rendering.protocol.*
1818import software.amazon.smithy.kotlin.codegen.rendering.serde.*
19- import software.amazon.smithy.kotlin.codegen.utils.dq
2019import software.amazon.smithy.model.shapes.*
2120import software.amazon.smithy.model.traits.*
2221
@@ -68,24 +67,6 @@ private class AwsQuerySerdeFormUrlDescriptorGenerator(
6867 member.hasTrait<XmlFlattenedTrait >()
6968}
7069
71- private class AwsQuerySerdeXmlDescriptorGenerator (
72- ctx : RenderingContext <Shape >,
73- memberShapes : List <MemberShape >? = null ,
74- ) : XmlSerdeDescriptorGenerator(ctx, memberShapes) {
75-
76- override fun getObjectDescriptorTraits (): List <SdkFieldDescriptorTrait > {
77- val traits = super .getObjectDescriptorTraits().toMutableList()
78-
79- if (objectShape.hasTrait<OperationOutput >()) {
80- traits.removeIf { it.symbol == RuntimeTypes .Serde .SerdeXml .XmlSerialName }
81- val serialName = objectShape.changeNameSuffix(" Response" to " Result" )
82- traits.add(RuntimeTypes .Serde .SerdeXml .XmlSerialName , serialName.dq())
83- }
84-
85- return traits
86- }
87- }
88-
8970private class AwsQuerySerializerGenerator (
9071 private val protocolGenerator : AwsQuery ,
9172) : AbstractQueryFormUrlSerializerGenerator(protocolGenerator, protocolGenerator.defaultTimestampFormat) {
@@ -98,50 +79,76 @@ private class AwsQuerySerializerGenerator(
9879}
9980
10081private class AwsQueryXmlParserGenerator (
101- private val protocolGenerator : AwsQuery ,
102- ) : XmlParserGenerator(protocolGenerator, protocolGenerator.defaultTimestampFormat) {
103-
104- override fun descriptorGenerator (
105- ctx : ProtocolGenerator .GenerationContext ,
106- shape : Shape ,
107- members : List <MemberShape >,
108- writer : KotlinWriter ,
109- ): XmlSerdeDescriptorGenerator = AwsQuerySerdeXmlDescriptorGenerator (ctx.toRenderingContext(protocolGenerator, shape, writer), members)
110-
111- override fun renderDeserializeOperationBody (
112- ctx : ProtocolGenerator .GenerationContext ,
113- op : OperationShape ,
114- documentMembers : List <MemberShape >,
115- writer : KotlinWriter ,
116- ) {
117- writer.write(" val deserializer = #T(payload)" , RuntimeTypes .Serde .SerdeXml .XmlDeserializer )
118- unwrapOperationResponseBody(op.id.name, writer)
119- val shape = ctx.model.expectShape(op.output.get())
120- renderDeserializerBody(ctx, shape, documentMembers, writer)
121- }
82+ protocolGenerator : AwsQuery ,
83+ ) : XmlParserGenerator(protocolGenerator.defaultTimestampFormat) {
12284
12385 /* *
12486 * Unwraps the response body as specified by
12587 * https://awslabs.github.io/smithy/1.0/spec/aws/aws-query-protocol.html#response-serialization so that the
12688 * deserializer is in the correct state.
89+ *
90+ * ```
91+ * <SomeOperationResponse>
92+ * <SomeOperationResult>
93+ * <-- SAME AS REST XML -->
94+ * </SomeOperationResult>
95+ *</SomeOperationResponse>
96+ * ```
12797 */
128- private fun unwrapOperationResponseBody (
129- operationName : String ,
98+ override fun unwrapOperationBody (
99+ ctx : ProtocolGenerator .GenerationContext ,
100+ serdeCtx : SerdeCtx ,
101+ op : OperationShape ,
130102 writer : KotlinWriter ,
131- ) {
132- writer.write(" // begin unwrap response wrapper" )
133- .write(" val resultDescriptor = #T(#T.Struct, #T(#S))" , RuntimeTypes .Serde .SdkFieldDescriptor , RuntimeTypes .Serde .SerialKind , RuntimeTypes .Serde .SerdeXml .XmlSerialName , " ${operationName} Result" )
134- .withBlock(" val wrapperDescriptor = #T.build {" , " }" , RuntimeTypes .Serde .SdkObjectDescriptor ) {
135- write(" trait(#T(#S))" , RuntimeTypes .Serde .SerdeXml .XmlSerialName , " ${operationName} Response" )
136- write(" #T(resultDescriptor)" , RuntimeTypes .Serde .field)
103+ ): SerdeCtx {
104+ val operationName = op.id.getName(ctx.service)
105+
106+ val unwrapAwsQueryOperation = buildSymbol {
107+ name = " unwrapAwsQueryResponse"
108+ namespace = ctx.settings.pkg.serde
109+ definitionFile = " AwsQueryUtil.kt"
110+ renderBy = { writer ->
111+
112+ writer.withBlock(
113+ " internal fun $name (root: #1T, operationName: #2T): #1T {" ,
114+ " }" ,
115+ RuntimeTypes .Serde .SerdeXml .XmlTagReader ,
116+ KotlinTypes .String ,
117+ ) {
118+ write(" val responseWrapperName = \" \$ {operationName}Response\" " )
119+ write(" val resultWrapperName = \" \$ {operationName}Result\" " )
120+ withBlock(
121+ " if (root.tagName != responseWrapperName) {" ,
122+ " }" ,
123+ ) {
124+ write(" throw #T(#S)" , RuntimeTypes .Serde .DeserializationException , " invalid root, expected \$ responseWrapperName; found `\$ {root.tag}`" )
125+ }
126+
127+ write(" val resultTag = ${serdeCtx.tagReader} .nextTag()" )
128+ withBlock(
129+ " if (resultTag == null || resultTag.tagName != resultWrapperName) {" ,
130+ " }" ,
131+ ) {
132+ write(" throw #T(#S)" , RuntimeTypes .Serde .DeserializationException , " invalid result, expected \$ resultWrapperName; found `\$ {resultTag?.tag}`" )
133+ }
134+
135+ write(" return resultTag" )
136+ }
137137 }
138- .write(" " )
139- // abandon the iterator, this only occurs at the top level operational output
140- .write(" val wrapper = deserializer.#T(wrapperDescriptor)" , RuntimeTypes .Serde .deserializeStruct)
141- .withBlock(" if (wrapper.findNextFieldIndex() != resultDescriptor.index) {" , " }" ) {
142- write(" throw #T(#S)" , RuntimeTypes .Serde .DeserializationException , " failed to unwrap $operationName response" )
143- }
144- .write(" // end unwrap response wrapper" )
145- .write(" " )
138+ }
139+
140+ writer.write(" val unwrapped = #T(#L, #S)" , unwrapAwsQueryOperation, serdeCtx.tagReader, operationName)
141+
142+ return SerdeCtx (" unwrapped" )
143+ }
144+
145+ override fun unwrapOperationError (
146+ ctx : ProtocolGenerator .GenerationContext ,
147+ serdeCtx : SerdeCtx ,
148+ errorShape : StructureShape ,
149+ writer : KotlinWriter ,
150+ ): SerdeCtx {
151+ writer.write(" val errReader = #T(${serdeCtx.tagReader} )" , RestXmlErrors .wrappedErrorResponseDeserializer(ctx))
152+ return SerdeCtx (" errReader" )
146153 }
147154}
0 commit comments