@@ -120,10 +120,7 @@ private void generateOperationExecutor(PythonWriter writer) {
120120 var hasStreaming = hasEventStream ();
121121 writer .putContext ("hasEventStream" , hasStreaming );
122122 if (hasStreaming ) {
123- writer .addImports ("smithy_core.deserializers" ,
124- Set .of (
125- "ShapeDeserializer" ,
126- "DeserializeableShape" ));
123+ writer .addImport ("smithy_core.deserializers" , "ShapeDeserializer" );
127124 writer .addStdlibImport ("typing" , "Any" );
128125 }
129126
@@ -137,7 +134,8 @@ private void generateOperationExecutor(PythonWriter writer) {
137134 writer .addStdlibImport ("typing" , "Awaitable" );
138135 writer .addStdlibImport ("typing" , "cast" );
139136 writer .addStdlibImport ("copy" , "deepcopy" );
140- writer .addStdlibImport ("asyncio" , "sleep" );
137+ writer .addStdlibImport ("asyncio" );
138+ writer .addStdlibImports ("asyncio" , Set .of ("sleep" , "Future" ));
141139
142140 writer .addDependency (SmithyPythonDependency .SMITHY_CORE );
143141 writer .addImport ("smithy_core.exceptions" , "SmithyRetryException" );
@@ -187,6 +185,75 @@ def _classify_error(
187185 """ );
188186 writer .dedent ();
189187
188+ if (hasStreaming ) {
189+ writer .addStdlibImports ("typing" , Set .of ("Any" , "Awaitable" ));
190+ writer .addStdlibImport ("asyncio" );
191+ writer .write (
192+ """
193+ async def _input_stream(
194+ self,
195+ input: Input,
196+ plugins: list[$1T],
197+ serialize: Callable[[Input, $4T], Awaitable[$2T]],
198+ deserialize: Callable[[$3T, $4T], Awaitable[Output]],
199+ config: $4T,
200+ operation_name: str,
201+ ) -> Any:
202+ request_future = Future[$2T]()
203+ awaitable_output = asyncio.create_task(self._execute_operation(
204+ input, plugins, serialize, deserialize, config, operation_name,
205+ request_future=request_future
206+ ))
207+ transport_request = await request_future
208+ ${5C|}
209+
210+ async def _output_stream(
211+ self,
212+ input: Input,
213+ plugins: list[$1T],
214+ serialize: Callable[[Input, $4T], Awaitable[$2T]],
215+ deserialize: Callable[[$3T, $4T], Awaitable[Output]],
216+ config: $4T,
217+ operation_name: str,
218+ event_deserializer: Callable[[ShapeDeserializer], Any],
219+ ) -> Any:
220+ response_future = Future[$3T]()
221+ output = await self._execute_operation(
222+ input, plugins, serialize, deserialize, config, operation_name,
223+ response_future=response_future
224+ )
225+ transport_response = await response_future
226+ ${6C|}
227+
228+ async def _duplex_stream(
229+ self,
230+ input: Input,
231+ plugins: list[$1T],
232+ serialize: Callable[[Input, $4T], Awaitable[$2T]],
233+ deserialize: Callable[[$3T, $4T], Awaitable[Output]],
234+ config: $4T,
235+ operation_name: str,
236+ event_deserializer: Callable[[ShapeDeserializer], Any],
237+ ) -> Any:
238+ request_future = Future[$2T]()
239+ response_future = Future[$3T]()
240+ awaitable_output = asyncio.create_task(self._execute_operation(
241+ input, plugins, serialize, deserialize, config, operation_name,
242+ request_future=request_future,
243+ response_future=response_future
244+ ))
245+ transport_request = await request_future
246+ ${7C|}
247+ """ ,
248+ pluginSymbol ,
249+ transportRequest ,
250+ transportResponse ,
251+ configSymbol ,
252+ writer .consumer (w -> context .protocolGenerator ().wrapInputStream (context , w )),
253+ writer .consumer (w -> context .protocolGenerator ().wrapOutputStream (context , w )),
254+ writer .consumer (w -> context .protocolGenerator ().wrapDuplexStream (context , w )));
255+ }
256+
190257 writer .write (
191258 """
192259 async def _execute_operation(
@@ -197,25 +264,25 @@ async def _execute_operation(
197264 deserialize: Callable[[$3T, $5T], Awaitable[Output]],
198265 config: $5T,
199266 operation_name: str,
200- ${?hasEventStream}
201- has_input_stream: bool = False,
202- event_deserializer: Callable[[ShapeDeserializer], Any] | None = None,
203- event_response_deserializer: DeserializeableShape | None = None,
204- ${/hasEventStream}
267+ request_future: Future[$2T] | None = None,
268+ response_future: Future[$3T] | None = None,
205269 ) -> Output:
206270 try:
207271 return await self._handle_execution(
208272 input, plugins, serialize, deserialize, config, operation_name,
209- ${?hasEventStream}
210- has_input_stream, event_deserializer, event_response_deserializer,
211- ${/hasEventStream}
273+ request_future, response_future,
212274 )
213275 except Exception as e:
276+ if request_future is not None and not request_future.done:
277+ request_future.set_exception($4T(e))
278+ if response_future is not None and not response_future.done:
279+ response_future.set_exception($4T(e))
280+
214281 # Make sure every exception that we throw is an instance of $4T so
215282 # customers can reliably catch everything we throw.
216283 if not isinstance(e, $4T):
217284 raise $4T(e) from e
218- raise e
285+ raise
219286
220287 async def _handle_execution(
221288 self,
@@ -225,11 +292,8 @@ async def _handle_execution(
225292 deserialize: Callable[[$3T, $5T], Awaitable[Output]],
226293 config: $5T,
227294 operation_name: str,
228- ${?hasEventStream}
229- has_input_stream: bool = False,
230- event_deserializer: Callable[[ShapeDeserializer], Any] | None = None,
231- event_response_deserializer: DeserializeableShape | None = None,
232- ${/hasEventStream}
295+ request_future: Future[$2T] | None,
296+ response_future: Future[$3T] | None,
233297 ) -> Output:
234298 logger.debug('Making request for operation "%s" with parameters: %s', operation_name, input)
235299 context: InterceptorContext[Input, None, None, None] = InterceptorContext(
@@ -307,6 +371,7 @@ async def _handle_execution(
307371 context_with_transport_request.copy(),
308372 config,
309373 operation_name,
374+ request_future,
310375 )
311376
312377 # We perform this type-ignored re-assignment because `context` needs
@@ -342,6 +407,10 @@ await seek(0)
342407 else:
343408 # Step 8: Invoke record_success
344409 retry_strategy.record_success(token=retry_token)
410+ if response_future is not None:
411+ response_future.set_result(
412+ context_with_response.response, # type: ignore
413+ )
345414 break
346415 except Exception as e:
347416 if context.response is not None:
@@ -355,16 +424,7 @@ await seek(0)
355424 execution_context = cast(
356425 InterceptorContext[Input, Output, $2T | None, $3T | None], context
357426 )
358- ${^hasEventStream}
359427 return await self._finalize_execution(interceptors, execution_context)
360- ${/hasEventStream}
361- ${?hasEventStream}
362- operation_output = await self._finalize_execution(interceptors, execution_context)
363- if has_input_stream or event_deserializer is not None:
364- ${6C|}
365- else:
366- return operation_output
367- ${/hasEventStream}
368428
369429 async def _handle_attempt(
370430 self,
@@ -373,6 +433,7 @@ async def _handle_attempt(
373433 context: InterceptorContext[Input, None, $2T, None],
374434 config: $5T,
375435 operation_name: str,
436+ request_future: Future[$2T] | None,
376437 ) -> InterceptorContext[Input, Output, $2T, $3T | None]:
377438 try:
378439 # assert config.interceptors is not None
@@ -385,8 +446,7 @@ async def _handle_attempt(
385446 transportRequest ,
386447 transportResponse ,
387448 errorSymbol ,
388- configSymbol ,
389- writer .consumer (w -> context .protocolGenerator ().wrapEventStream (context , w )));
449+ configSymbol );
390450
391451 boolean supportsAuth = !ServiceIndex .of (context .model ()).getAuthSchemes (service ).isEmpty ();
392452 writer .pushState (new ResolveIdentitySection ());
@@ -533,10 +593,19 @@ async def _handle_attempt(
533593 )
534594 logger.debug("HTTP request config: %s", request_config)
535595 logger.debug("Sending HTTP request: %s", context_with_response.transport_request)
536- context_with_response._transport_response = await config.http_client.send(
537- request=context_with_response.transport_request,
538- request_config=request_config,
539- )
596+
597+ if request_future is not None:
598+ response_task = asyncio.create_task(config.http_client.send(
599+ request=context_with_response.transport_request,
600+ request_config=request_config,
601+ ))
602+ request_future.set_result(context_with_response.transport_request)
603+ context_with_response._transport_response = await response_task
604+ else:
605+ context_with_response._transport_response = await config.http_client.send(
606+ request=context_with_response.transport_request,
607+ request_config=request_config,
608+ )
540609 logger.debug("Received HTTP response: %s", context_with_response.transport_response)
541610
542611 """ , transportRequest , transportResponse );
@@ -834,16 +903,14 @@ private void generateEventStreamOperation(PythonWriter writer, OperationShape op
834903 raise NotImplementedError()
835904 ${/hasProtocol}
836905 ${?hasProtocol}
837- return await self._execute_operation (
906+ return await self._duplex_stream (
838907 input=input,
839908 plugins=operation_plugins,
840909 serialize=${serSymbol:T},
841910 deserialize=${deserSymbol:T},
842911 config=self._config,
843912 operation_name=${operationName:S},
844- has_input_stream=True,
845913 event_deserializer=$T().deserialize,
846- event_response_deserializer=${output:T},
847914 ) # type: ignore
848915 ${/hasProtocol}
849916 """ ,
@@ -862,14 +929,13 @@ raise NotImplementedError()
862929 raise NotImplementedError()
863930 ${/hasProtocol}
864931 ${?hasProtocol}
865- return await self._execute_operation (
932+ return await self._input_stream (
866933 input=input,
867934 plugins=operation_plugins,
868935 serialize=${serSymbol:T},
869936 deserialize=${deserSymbol:T},
870937 config=self._config,
871938 operation_name=${operationName:S},
872- has_input_stream=True,
873939 ) # type: ignore
874940 ${/hasProtocol}
875941 """ , writer .consumer (w -> writeSharedOperationInit (w , operation , input )));
@@ -887,15 +953,14 @@ raise NotImplementedError()
887953 raise NotImplementedError()
888954 ${/hasProtocol}
889955 ${?hasProtocol}
890- return await self._execute_operation (
956+ return await self._output_stream (
891957 input=input,
892958 plugins=operation_plugins,
893959 serialize=${serSymbol:T},
894960 deserialize=${deserSymbol:T},
895961 config=self._config,
896962 operation_name=${operationName:S},
897963 event_deserializer=$T().deserialize,
898- event_response_deserializer=${output:T},
899964 ) # type: ignore
900965 ${/hasProtocol}
901966 """ ,
0 commit comments