Skip to content

Commit d41629d

Browse files
Update generated stream wrappers
1 parent 48c4786 commit d41629d

File tree

4 files changed

+43
-37
lines changed

4 files changed

+43
-37
lines changed

codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,13 @@ def _classify_error(
196196
if (hasStreaming) {
197197
writer.addStdlibImports("typing", Set.of("Any", "Awaitable"));
198198
writer.addStdlibImport("asyncio");
199+
200+
writer.addImports("smithy_core.aio.eventstream",
201+
Set.of(
202+
"InputEventStream",
203+
"OutputEventStream",
204+
"DuplexEventStream"));
205+
writer.addImport("smithy_core.aio.interfaces.eventstream", "EventReceiver");
199206
writer.write(
200207
"""
201208
async def _input_stream(
@@ -214,6 +221,10 @@ async def _input_stream(
214221
))
215222
request_context = await request_future
216223
${5C|}
224+
return InputEventStream[Any, Any](
225+
input_stream=publisher,
226+
output_future=awaitable_output,
227+
)
217228
218229
async def _output_stream(
219230
self,
@@ -232,6 +243,10 @@ async def _output_stream(
232243
)
233244
transport_response = await response_future
234245
${6C|}
246+
return OutputEventStream[Any, Any](
247+
output_stream=receiver,
248+
output=output
249+
)
235250
236251
async def _duplex_stream(
237252
self,
@@ -251,15 +266,34 @@ async def _duplex_stream(
251266
response_future=response_future
252267
))
253268
request_context = await request_future
254-
${7C|}
269+
${5C|}
270+
output_future = asyncio.create_task(self._wrap_duplex_output(
271+
response_future, awaitable_output, config, operation_name,
272+
event_deserializer
273+
))
274+
return DuplexEventStream[Any, Any, Any](
275+
input_stream=publisher,
276+
output_future=output_future,
277+
)
278+
279+
async def _wrap_duplex_output(
280+
self,
281+
response_future: Future[$3T],
282+
awaitable_output: Future[Any],
283+
config: $4T,
284+
operation_name: str,
285+
event_deserializer: Callable[[ShapeDeserializer], Any],
286+
) -> tuple[Any, EventReceiver[Any]]:
287+
transport_response = await response_future
288+
${6C|}
289+
return await awaitable_output, receiver
255290
""",
256291
pluginSymbol,
257292
transportRequest,
258293
transportResponse,
259294
configSymbol,
260295
writer.consumer(w -> context.protocolGenerator().wrapInputStream(context, w)),
261-
writer.consumer(w -> context.protocolGenerator().wrapOutputStream(context, w)),
262-
writer.consumer(w -> context.protocolGenerator().wrapDuplexStream(context, w)));
296+
writer.consumer(w -> context.protocolGenerator().wrapOutputStream(context, w)));
263297
}
264298
writer.addStdlibImport("typing", "Any");
265299
writer.write(
@@ -899,7 +933,6 @@ private void generateEventStreamOperation(PythonWriter writer, OperationShape op
899933

900934
if (inputStreamSymbol != null) {
901935
if (outputStreamSymbol != null) {
902-
writer.addImport("smithy_event_stream.aio.interfaces", "DuplexEventStream");
903936
writer.write("""
904937
async def ${operationName:L}(
905938
self,
@@ -949,7 +982,6 @@ raise NotImplementedError()
949982
""", writer.consumer(w -> writeSharedOperationInit(w, operation, input)));
950983
}
951984
} else {
952-
writer.addImport("smithy_event_stream.aio.interfaces", "OutputEventStream");
953985
writer.write("""
954986
async def ${operationName:L}(
955987
self,

codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/OperationGenerator.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,7 @@ public void run() {
6868
inSymbol.expectProperty(SymbolProperties.SCHEMA),
6969
outSymbol.expectProperty(SymbolProperties.SCHEMA),
7070
writer.consumer(this::writeErrorTypeRegistry),
71-
writer.consumer(this::writeAuthSchemes)
72-
);
71+
writer.consumer(this::writeAuthSchemes));
7372
}
7473

7574
private void writeErrorTypeRegistry(PythonWriter writer) {

codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ProtocolGenerator.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,4 @@ default void generateProtocolTests(GenerationContext context) {}
157157
default void wrapInputStream(GenerationContext context, PythonWriter writer) {}
158158

159159
default void wrapOutputStream(GenerationContext context, PythonWriter writer) {}
160-
161-
default void wrapDuplexStream(GenerationContext context, PythonWriter writer) {}
162160
}

codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -396,13 +396,12 @@ public void wrapInputStream(GenerationContext context, PythonWriter writer) {
396396
writer.addImport("smithy_json", "JSONCodec");
397397
writer.addImport("smithy_core.aio.types", "AsyncBytesReader");
398398
writer.addImport("smithy_core.types", "TimestampFormat");
399-
writer.addImport("aws_event_stream.aio", "AWSInputEventStream");
399+
writer.addImport("aws_event_stream.aio", "AWSEventPublisher");
400400
writer.write(
401401
"""
402402
codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS)
403-
return AWSInputEventStream[Any, Any](
403+
publisher = AWSEventPublisher[Any](
404404
payload_codec=codec,
405-
awaitable_output=awaitable_output,
406405
async_writer=request_context.transport_request.body, # type: ignore
407406
)
408407
""");
@@ -415,39 +414,17 @@ public void wrapOutputStream(GenerationContext context, PythonWriter writer) {
415414
writer.addImport("smithy_json", "JSONCodec");
416415
writer.addImport("smithy_core.aio.types", "AsyncBytesReader");
417416
writer.addImport("smithy_core.types", "TimestampFormat");
418-
writer.addImport("aws_event_stream.aio", "AWSOutputEventStream");
417+
writer.addImport("aws_event_stream.aio", "AWSEventReceiver");
419418
writer.write(
420419
"""
421420
codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS)
422-
return AWSOutputEventStream[Any, Any](
421+
receiver = AWSEventReceiver(
423422
payload_codec=codec,
424-
initial_response=output,
425-
async_reader=AsyncBytesReader(
423+
source=AsyncBytesReader(
426424
transport_response.body # type: ignore
427425
),
428426
deserializer=event_deserializer, # type: ignore
429427
)
430428
""");
431429
}
432-
433-
@Override
434-
public void wrapDuplexStream(GenerationContext context, PythonWriter writer) {
435-
writer.addDependency(SmithyPythonDependency.SMITHY_JSON);
436-
writer.addDependency(SmithyPythonDependency.AWS_EVENT_STREAM);
437-
writer.addImport("smithy_json", "JSONCodec");
438-
writer.addImport("smithy_core.aio.types", "AsyncBytesReader");
439-
writer.addImport("smithy_core.types", "TimestampFormat");
440-
writer.addImport("aws_event_stream.aio", "AWSDuplexEventStream");
441-
writer.write(
442-
"""
443-
codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS)
444-
return AWSDuplexEventStream[Any, Any, Any](
445-
payload_codec=codec,
446-
async_writer=request_context.transport_request.body, # type: ignore
447-
awaitable_output=awaitable_output,
448-
awaitable_response=response_future,
449-
deserializer=event_deserializer, # type: ignore
450-
)
451-
""");
452-
}
453430
}

0 commit comments

Comments
 (0)