11# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22# SPDX-License-Identifier: Apache-2.0
33import logging
4- from asyncio import sleep
4+ import asyncio
5+ from asyncio import sleep , Future
56from dataclasses import dataclass , replace
6- from typing import Any
7+ from typing import Any , Callable , Awaitable
78
89from .interfaces import ClientProtocol , Request , Response , ClientTransport
10+ from .interfaces .eventstream import DuplexEventStream , InputEventStream , OutputEventStream , AsyncEventReceiver
911from .. import URI
1012from ..interfaces import TypedProperties , Endpoint
1113from ..interfaces .retries import RetryStrategy , RetryErrorInfo , RetryErrorType
1921from ..schemas import APIOperation
2022from ..shapes import ShapeID
2123from ..serializers import SerializeableShape
22- from ..deserializers import DeserializeableShape
24+ from ..deserializers import DeserializeableShape , ShapeDeserializer
2325from ..exceptions import SmithyRetryException , SmithyException
2426from ..types import PropertyKey
2527
@@ -85,12 +87,65 @@ def __init__(
8587 async def __call__ [I : SerializeableShape , O : DeserializeableShape ](
8688 self , call : ClientCall [I , O ], /
8789 ) -> O :
90+ output , _ = await self ._execute_request (call , None )
91+ return output
92+
    # NOTE(review): the event-stream entry points below are intentionally left
    # commented out — presumably staged until protocol-level event publisher /
    # receiver support lands (confirm before enabling or deleting). Also note:
    # they depend on `_execute_request` forwarding `request_future` into
    # `_handle_execution`; as written it passes `None`, so `await request_future`
    # in these methods would never resolve — verify before uncommenting.
    # async def input_stream[I: SerializeableShape, O: DeserializeableShape, E: SerializeableShape](
94+ # self, call: ClientCall[I, O], event_serializer: E, /
95+ # ) -> InputEventStream[E, O]:
96+ # request_future = Future[RequestContext[I, TRequest]]()
97+ # execute_task = asyncio.create_task(self._await_output(self._execute_request(call, request_future)))
98+ # request_context = await request_future
99+ # stream: InputEventStream[E, O] = InputEventStream(
100+ # awaitable_output=execute_task,
101+ # event_publisher=self.protocol.create_event_publisher(event_serializer, request_context)
102+ # )
103+ # return stream
104+
105+ # async def _await_output[I: SerializeableShape, O: DeserializeableShape](
106+ # self, execute_task: Awaitable[tuple[O, OutputContext[I, O, TRequest | None, TResponse | None]]]
107+ # ) -> O:
108+ # output, _ = await execute_task
109+ # return output
110+
111+ # async def output_stream[I: SerializeableShape, O: DeserializeableShape, E: DeserializeableShape](
112+ # self, call: ClientCall[I, O], event_deserializer: Callable[[ShapeDeserializer], E], /
113+ # ) -> OutputEventStream[E, O]:
114+ # output, output_context = await self._execute_request(call, None)
115+ # stream: OutputEventStream[E, O] = OutputEventStream(
116+ # output=output,
117+ # event_receiver=self.protocol.create_event_receiver(event_deserializer, output_context)
118+ # )
119+ # return stream
120+
121+ # async def duplex_stream[I: SerializeableShape, O: DeserializeableShape, IE: SerializeableShape, OE: DeserializeableShape](
122+ # self, call: ClientCall[I, O], event_serializer: IE, event_deserializer: Callable[[ShapeDeserializer], OE], /
123+ # ) -> DuplexEventStream[IE, OE, O]:
124+ # request_future = Future[RequestContext[I, TRequest]]()
125+ # execute_task = asyncio.create_task(self._execute_request(call, request_future))
126+ # request_context = await request_future
127+ # stream: DuplexEventStream[IE, OE, O] = DuplexEventStream(
128+ # awaitable_output=self._await_output_stream(execute_task, event_deserializer),
129+ # event_publisher=self.protocol.create_event_publisher(event_serializer, request_context)
130+ # )
131+ # return stream
132+
133+ # async def _await_output_stream[I: SerializeableShape, O: DeserializeableShape, E: DeserializeableShape](
134+ # self, execute_task: Awaitable[tuple[O, OutputContext[I, O, TRequest | None, TResponse | None]]],
135+ # event_deserializer: Callable[[ShapeDeserializer], E],
136+ # ) -> tuple[O, AsyncEventReceiver[E]]:
137+ # output, output_context = await execute_task
138+ # return output, self.protocol.create_event_receiver(event_deserializer, output_context)
139+
140+ async def _execute_request [I : SerializeableShape , O : DeserializeableShape ](
141+ self , call : ClientCall [I , O ], request_future : Future [RequestContext [I , TRequest ]] | None
142+ ) -> tuple [O , OutputContext [I , O , TRequest | None , TResponse | None ]]:
88143 _LOGGER .debug (
89144 'Making request for operation "%s" with parameters: %s' ,
90145 call .operation .schema .id .name ,
91146 call .input ,
92147 )
93- output_context = await self ._handle_execution (call )
148+ output_context = await self ._handle_execution (call , None )
94149 output_context = self ._finalize_execution (call , output_context )
95150
96151 if isinstance (output_context .response , Exception ):
@@ -99,13 +154,10 @@ async def __call__[I: SerializeableShape, O: DeserializeableShape](
99154 raise SmithyException (e ) from e
100155 raise e
101156
102- # TODO: wrap event streams
103- # This needs to go on the protocols
104-
105- return output_context .response
157+ return output_context .response , output_context
106158
107159 async def _handle_execution [I : SerializeableShape , O : DeserializeableShape ](
108- self , call : ClientCall [I , O ]
160+ self , call : ClientCall [I , O ], request_future : Future [ RequestContext [ I , TRequest ]] | None
109161 ) -> OutputContext [I , O , TRequest | None , TResponse | None ]:
110162 try :
111163 interceptor = call .interceptor
@@ -163,7 +215,7 @@ async def _handle_execution[I: SerializeableShape, O: DeserializeableShape](
163215 transport_request = interceptor .modify_before_retry_loop (request_context ),
164216 )
165217
166- return await self ._retry (call , request_context )
218+ return await self ._retry (call , request_context , request_future )
167219 except Exception as e :
168220 return OutputContext (
169221 request = request_context .request ,
@@ -174,7 +226,7 @@ async def _handle_execution[I: SerializeableShape, O: DeserializeableShape](
174226 )
175227
176228 async def _retry [I : SerializeableShape , O : DeserializeableShape ](
177- self , call : ClientCall [I , O ], request_context : RequestContext [I , TRequest ]
229+ self , call : ClientCall [I , O ], request_context : RequestContext [I , TRequest ], request_future : Future [ RequestContext [ I , TRequest ]] | None
178230 ) -> OutputContext [I , O , TRequest | None , TResponse | None ]:
179231 # 8. Invoke AcquireInitialRetryToken
180232 retry_strategy = call .retry_strategy
@@ -187,7 +239,7 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape](
187239 if retry_token .retry_delay :
188240 await sleep (retry_token .retry_delay )
189241
190- output_context = await self ._handle_attempt (call , request_context )
242+ output_context = await self ._handle_attempt (call , request_context , request_future )
191243
192244 if isinstance (output_context .response , Exception ):
193245 try :
@@ -217,7 +269,7 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape](
217269 return output_context
218270
219271 async def _handle_attempt [I : SerializeableShape , O : DeserializeableShape ](
220- self , call : ClientCall [I , O ], request_context : RequestContext [I , TRequest ]
272+ self , call : ClientCall [I , O ], request_context : RequestContext [I , TRequest ], request_future : Future [ RequestContext [ I , TRequest ]] | None
221273 ) -> OutputContext [I , O , TRequest , TResponse | None ]:
222274 output_context : OutputContext [I , O , TRequest , TResponse | None ]
223275 try :
@@ -294,9 +346,24 @@ async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape](
294346
295347 _LOGGER .debug ("Sending request %s" , request_context .transport_request )
296348
297- transport_response = await self .transport .send (
298- request = request_context .transport_request
299- )
349+ if request_future is not None :
350+ # If we have an input event stream (or duplex event stream) then we
351+ # need to let the client return ASAP so that it can start sending
352+ # events. So here we start the transport send in a background task
353+ # then set the result of the request future. It's important to sequence
354+ # it just like that so that the client gets a stream that's ready
355+ # to send.
356+ transport_task = asyncio .create_task (self .transport .send (
357+ request = request_context .transport_request
358+ ))
359+ request_future .set_result (request_context )
360+ transport_response = await transport_task
361+ else :
362+ # If we don't have an input stream, there's no point in creating a
363+ # task, so we just immediately await the coroutine.
364+ transport_response = await self .transport .send (
365+ request = request_context .transport_request
366+ )
300367
301368 _LOGGER .debug ("Received response: %s" , transport_response )
302369
0 commit comments