Skip to content

Commit b759e19

Browse files
authored
feat: support initial-response in RPC based event streams (#597)
1 parent 81a5043 commit b759e19

File tree

4 files changed

+170
-4
lines changed

4 files changed

+170
-4
lines changed

smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/httpResponse/bindingTraits/HttpResponseTraitWithoutHttpPayload.kt

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,18 @@ class HttpResponseTraitWithoutHttpPayload(
5050
.filter { !it.member.hasTrait(HttpQueryTrait::class.java) }
5151
.toMutableSet()
5252
val streamingMember = bodyMembers.firstOrNull { it.member.targetOrSelf(ctx.model).hasTrait(StreamingTrait::class.java) }
53-
5453
if (streamingMember != null) {
55-
writeStreamingMember(streamingMember)
54+
val initialResponseMembers = bodyMembers.filter {
55+
val targetShape = it.member.targetOrSelf(ctx.model)
56+
targetShape?.hasTrait(StreamingTrait::class.java) == false
57+
}.toSet()
58+
writeStreamingMember(streamingMember, initialResponseMembers)
5659
} else if (bodyMembersWithoutQueryTrait.isNotEmpty()) {
5760
writeNonStreamingMembers(bodyMembersWithoutQueryTrait)
5861
}
5962
}
6063

61-
fun writeStreamingMember(streamingMember: HttpBindingDescriptor) {
64+
fun writeStreamingMember(streamingMember: HttpBindingDescriptor, initialResponseMembers: Set<HttpBindingDescriptor>) {
6265
val shape = ctx.model.expectShape(streamingMember.member.target)
6366
val symbol = ctx.symbolProvider.toSymbol(shape)
6467
val memberName = ctx.symbolProvider.toMemberName(streamingMember.member)
@@ -74,6 +77,9 @@ class HttpResponseTraitWithoutHttpPayload(
7477
symbol
7578
)
7679
writer.write("self.\$L = decoderStream.toAsyncStream()", memberName)
80+
if (isRPCService(ctx) && initialResponseMembers.isNotEmpty()) {
81+
writeInitialResponseMembers(initialResponseMembers)
82+
}
7783
}
7884
writer.indent()
7985
writer.write("self.\$L = nil", memberName).closeBlock("}")
@@ -133,4 +139,52 @@ class HttpResponseTraitWithoutHttpPayload(
133139
}
134140

135141
private val path: String = "properties.".takeIf { outputShape.hasTrait<ErrorTrait>() } ?: ""
142+
143+
private fun writeInitialResponseMembers(initialResponseMembers: Set<HttpBindingDescriptor>) {
144+
writer.apply {
145+
write("if let initialDataWithoutHttp = await messageDecoder.awaitInitialResponse() {")
146+
indent()
147+
write("let decoder = JSONDecoder()")
148+
write("do {")
149+
indent()
150+
write("let response = try decoder.decode([String: String].self, from: initialDataWithoutHttp)")
151+
initialResponseMembers.forEach { responseMember ->
152+
val responseMemberName = ctx.symbolProvider.toMemberName(responseMember.member)
153+
write("self.$responseMemberName = response[\"$responseMemberName\"].map { value in KinesisClientTypes.Tag(value: value) }")
154+
}
155+
dedent()
156+
write("} catch {")
157+
indent()
158+
write("print(\"Error decoding JSON: \\(error)\")")
159+
initialResponseMembers.forEach { responseMember ->
160+
val responseMemberName = ctx.symbolProvider.toMemberName(responseMember.member)
161+
write("self.$responseMemberName = nil")
162+
}
163+
dedent()
164+
write("}")
165+
dedent()
166+
write("} else {")
167+
indent()
168+
initialResponseMembers.forEach { responseMember ->
169+
val responseMemberName = ctx.symbolProvider.toMemberName(responseMember.member)
170+
write("self.$responseMemberName = nil")
171+
}
172+
dedent()
173+
write("}")
174+
}
175+
}
176+
177+
private fun isRPCService(ctx: ProtocolGenerator.GenerationContext): Boolean {
178+
return rpcBoundProtocols.contains(ctx.protocol.name)
179+
}
180+
181+
/**
182+
* A set of RPC-bound Smithy protocols
183+
*/
184+
private val rpcBoundProtocols = setOf(
185+
"awsJson1_0",
186+
"awsJson1_1",
187+
"awsQuery",
188+
"ec2Query",
189+
)
136190
}
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
import io.kotest.matchers.string.shouldContainOnlyOnce
7+
import mocks.MockHttpAWSJson11ProtocolGenerator
8+
import org.junit.jupiter.api.Test
9+
import software.amazon.smithy.swift.codegen.integration.HttpBindingProtocolGenerator
10+
11+
class EventStreamsInitialResponseTests {
12+
@Test
13+
fun `should attempt to decode response if initial-response members are present in RPC (awsJson) smithy model`() {
14+
val context = setupInitialMessageTests(
15+
"event-stream-initial-request-response.smithy",
16+
"com.test#Example",
17+
MockHttpAWSJson11ProtocolGenerator()
18+
)
19+
val contents = getFileContents(
20+
context.manifest,
21+
"/InitialMessageEventStreams/models/TestStreamOperationWithInitialRequestResponseOutput+HttpResponseBinding.swift"
22+
)
23+
contents.shouldSyntacticSanityCheck()
24+
contents.shouldContainOnlyOnce(
25+
"""
26+
extension TestStreamOperationWithInitialRequestResponseOutput: ClientRuntime.HttpResponseBinding {
27+
public init(httpResponse: ClientRuntime.HttpResponse, decoder: ClientRuntime.ResponseDecoder? = nil) async throws {
28+
if case let .stream(stream) = httpResponse.body, let responseDecoder = decoder {
29+
let messageDecoder: ClientRuntime.MessageDecoder? = nil
30+
let decoderStream = ClientRuntime.EventStream.DefaultMessageDecoderStream<InitialMessageEventStreamsClientTypes.TestStream>(stream: stream, messageDecoder: messageDecoder, responseDecoder: responseDecoder)
31+
self.value = decoderStream.toAsyncStream()
32+
if let initialDataWithoutHttp = await messageDecoder.awaitInitialResponse() {
33+
let decoder = JSONDecoder()
34+
do {
35+
let response = try decoder.decode([String: String].self, from: initialDataWithoutHttp)
36+
self.initial1 = response["initial1"].map { value in KinesisClientTypes.Tag(value: value) }
37+
self.initial2 = response["initial2"].map { value in KinesisClientTypes.Tag(value: value) }
38+
} catch {
39+
print("Error decoding JSON: \(error)")
40+
self.initial1 = nil
41+
self.initial2 = nil
42+
}
43+
} else {
44+
self.initial1 = nil
45+
self.initial2 = nil
46+
}
47+
} else {
48+
self.value = nil
49+
}
50+
}
51+
}
52+
""".trimIndent()
53+
)
54+
}
55+
56+
private fun setupInitialMessageTests(
57+
smithyFile: String,
58+
serviceShapeId: String,
59+
protocolGenerator: HttpBindingProtocolGenerator
60+
): TestContext {
61+
val context = TestContext.initContextFrom(smithyFile, serviceShapeId, protocolGenerator) { model ->
62+
model.defaultSettings(serviceShapeId, "InitialMessageEventStreams", "123", "InitialMessageEventStreams")
63+
}
64+
context.generator.initializeMiddleware(context.generationCtx)
65+
context.generator.generateSerializers(context.generationCtx)
66+
context.generator.generateProtocolClient(context.generationCtx)
67+
context.generator.generateDeserializers(context.generationCtx)
68+
context.generator.generateCodableConformanceForNestedTypes(context.generationCtx)
69+
context.generationCtx.delegator.flushWriters()
70+
return context
71+
}
72+
}

smithy-swift-codegen/src/test/kotlin/mocks/MockHttpAWSJson11ProtocolGenerator.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ class MockAWSJson11HttpProtocolCustomizations() : DefaultHttpProtocolCustomizati
5353
writer: SwiftWriter,
5454
op: OperationShape,
5555
) {
56-
TODO("Not yet implemented")
56+
// Not yet implemented
57+
return
5758
}
5859
}
5960

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
namespace com.test
2+
3+
use aws.protocols#awsJson1_1
4+
use aws.api#service
5+
use aws.auth#sigv4
6+
7+
@awsJson1_1
8+
@sigv4(name: "event-stream-test")
9+
@service(sdkId: "InitialMessageEventStreams")
10+
service Example {
11+
version: "123",
12+
operations: [TestStreamOperationWithInitialRequestResponse]
13+
}
14+
15+
operation TestStreamOperationWithInitialRequestResponse {
16+
input: TestStreamInputOutputInitialRequestResponse,
17+
output: TestStreamInputOutputInitialRequestResponse,
18+
errors: [SomeError],
19+
}
20+
21+
structure TestStreamInputOutputInitialRequestResponse {
22+
@required
23+
value: TestStream
24+
initial1: String
25+
initial2: String
26+
}
27+
28+
@error("client")
29+
structure SomeError {
30+
Message: String,
31+
}
32+
33+
structure MessageWithString { @eventPayload data: String }
34+
35+
@streaming
36+
union TestStream {
37+
MessageWithString: MessageWithString,
38+
SomeError: SomeError,
39+
}

0 commit comments

Comments
 (0)