Skip to content

Commit a32bf5c

Browse files
foo
1 parent c8645f8 commit a32bf5c

File tree

1 file changed

+83
-16
lines changed
  • packages/smithy-core/src/smithy_core/aio

1 file changed

+83
-16
lines changed

packages/smithy-core/src/smithy_core/aio/client.py

Lines changed: 83 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33
import logging
4-
from asyncio import sleep
4+
import asyncio
5+
from asyncio import sleep, Future
56
from dataclasses import dataclass, replace
6-
from typing import Any
7+
from typing import Any, Callable, Awaitable
78

89
from .interfaces import ClientProtocol, Request, Response, ClientTransport
10+
from .interfaces.eventstream import DuplexEventStream, InputEventStream, OutputEventStream, AsyncEventReceiver
911
from .. import URI
1012
from ..interfaces import TypedProperties, Endpoint
1113
from ..interfaces.retries import RetryStrategy, RetryErrorInfo, RetryErrorType
@@ -19,7 +21,7 @@
1921
from ..schemas import APIOperation
2022
from ..shapes import ShapeID
2123
from ..serializers import SerializeableShape
22-
from ..deserializers import DeserializeableShape
24+
from ..deserializers import DeserializeableShape, ShapeDeserializer
2325
from ..exceptions import SmithyRetryException, SmithyException
2426
from ..types import PropertyKey
2527

@@ -85,12 +87,65 @@ def __init__(
8587
async def __call__[I: SerializeableShape, O: DeserializeableShape](
8688
self, call: ClientCall[I, O], /
8789
) -> O:
90+
output, _ = await self._execute_request(call, None)
91+
return output
92+
93+
# async def input_stream[I: SerializeableShape, O: DeserializeableShape, E: SerializeableShape](
94+
# self, call: ClientCall[I, O], event_serializer: E, /
95+
# ) -> InputEventStream[E, O]:
96+
# request_future = Future[RequestContext[I, TRequest]]()
97+
# execute_task = asyncio.create_task(self._await_output(self._execute_request(call, request_future)))
98+
# request_context = await request_future
99+
# stream: InputEventStream[E, O] = InputEventStream(
100+
# awaitable_output=execute_task,
101+
# event_publisher=self.protocol.create_event_publisher(event_serializer, request_context)
102+
# )
103+
# return stream
104+
105+
# async def _await_output[I: SerializeableShape, O: DeserializeableShape](
106+
# self, execute_task: Awaitable[tuple[O, OutputContext[I, O, TRequest | None, TResponse | None]]]
107+
# ) -> O:
108+
# output, _ = await execute_task
109+
# return output
110+
111+
# async def output_stream[I: SerializeableShape, O: DeserializeableShape, E: DeserializeableShape](
112+
# self, call: ClientCall[I, O], event_deserializer: Callable[[ShapeDeserializer], E], /
113+
# ) -> OutputEventStream[E, O]:
114+
# output, output_context = await self._execute_request(call, None)
115+
# stream: OutputEventStream[E, O] = OutputEventStream(
116+
# output=output,
117+
# event_receiver=self.protocol.create_event_receiver(event_deserializer, output_context)
118+
# )
119+
# return stream
120+
121+
# async def duplex_stream[I: SerializeableShape, O: DeserializeableShape, IE: SerializeableShape, OE: DeserializeableShape](
122+
# self, call: ClientCall[I, O], event_serializer: IE, event_deserializer: Callable[[ShapeDeserializer], OE], /
123+
# ) -> DuplexEventStream[IE, OE, O]:
124+
# request_future = Future[RequestContext[I, TRequest]]()
125+
# execute_task = asyncio.create_task(self._execute_request(call, request_future))
126+
# request_context = await request_future
127+
# stream: DuplexEventStream[IE, OE, O] = DuplexEventStream(
128+
# awaitable_output=self._await_output_stream(execute_task, event_deserializer),
129+
# event_publisher=self.protocol.create_event_publisher(event_serializer, request_context)
130+
# )
131+
# return stream
132+
133+
# async def _await_output_stream[I: SerializeableShape, O: DeserializeableShape, E: DeserializeableShape](
134+
# self, execute_task: Awaitable[tuple[O, OutputContext[I, O, TRequest | None, TResponse | None]]],
135+
# event_deserializer: Callable[[ShapeDeserializer], E],
136+
# ) -> tuple[O, AsyncEventReceiver[E]]:
137+
# output, output_context = await execute_task
138+
# return output, self.protocol.create_event_receiver(event_deserializer, output_context)
139+
140+
async def _execute_request[I: SerializeableShape, O: DeserializeableShape](
141+
self, call: ClientCall[I, O], request_future: Future[RequestContext[I, TRequest]] | None
142+
) -> tuple[O, OutputContext[I, O, TRequest | None, TResponse | None]]:
88143
_LOGGER.debug(
89144
'Making request for operation "%s" with parameters: %s',
90145
call.operation.schema.id.name,
91146
call.input,
92147
)
93-
output_context = await self._handle_execution(call)
148+
output_context = await self._handle_execution(call, None)
94149
output_context = self._finalize_execution(call, output_context)
95150

96151
if isinstance(output_context.response, Exception):
@@ -99,13 +154,10 @@ async def __call__[I: SerializeableShape, O: DeserializeableShape](
99154
raise SmithyException(e) from e
100155
raise e
101156

102-
# TODO: wrap event streams
103-
# This needs to go on the protocols
104-
105-
return output_context.response
157+
return output_context.response, output_context
106158

107159
async def _handle_execution[I: SerializeableShape, O: DeserializeableShape](
108-
self, call: ClientCall[I, O]
160+
self, call: ClientCall[I, O], request_future: Future[RequestContext[I, TRequest]] | None
109161
) -> OutputContext[I, O, TRequest | None, TResponse | None]:
110162
try:
111163
interceptor = call.interceptor
@@ -163,7 +215,7 @@ async def _handle_execution[I: SerializeableShape, O: DeserializeableShape](
163215
transport_request=interceptor.modify_before_retry_loop(request_context),
164216
)
165217

166-
return await self._retry(call, request_context)
218+
return await self._retry(call, request_context, request_future)
167219
except Exception as e:
168220
return OutputContext(
169221
request=request_context.request,
@@ -174,7 +226,7 @@ async def _handle_execution[I: SerializeableShape, O: DeserializeableShape](
174226
)
175227

176228
async def _retry[I: SerializeableShape, O: DeserializeableShape](
177-
self, call: ClientCall[I, O], request_context: RequestContext[I, TRequest]
229+
self, call: ClientCall[I, O], request_context: RequestContext[I, TRequest], request_future: Future[RequestContext[I, TRequest]] | None
178230
) -> OutputContext[I, O, TRequest | None, TResponse | None]:
179231
# 8. Invoke AcquireInitialRetryToken
180232
retry_strategy = call.retry_strategy
@@ -187,7 +239,7 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape](
187239
if retry_token.retry_delay:
188240
await sleep(retry_token.retry_delay)
189241

190-
output_context = await self._handle_attempt(call, request_context)
242+
output_context = await self._handle_attempt(call, request_context, request_future)
191243

192244
if isinstance(output_context.response, Exception):
193245
try:
@@ -217,7 +269,7 @@ async def _retry[I: SerializeableShape, O: DeserializeableShape](
217269
return output_context
218270

219271
async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape](
220-
self, call: ClientCall[I, O], request_context: RequestContext[I, TRequest]
272+
self, call: ClientCall[I, O], request_context: RequestContext[I, TRequest], request_future: Future[RequestContext[I, TRequest]] | None
221273
) -> OutputContext[I, O, TRequest, TResponse | None]:
222274
output_context: OutputContext[I, O, TRequest, TResponse | None]
223275
try:
@@ -294,9 +346,24 @@ async def _handle_attempt[I: SerializeableShape, O: DeserializeableShape](
294346

295347
_LOGGER.debug("Sending request %s", request_context.transport_request)
296348

297-
transport_response = await self.transport.send(
298-
request=request_context.transport_request
299-
)
349+
if request_future is not None:
350+
# If we have an input event stream (or duplex event stream) then we
351+
# need to let the client return ASAP so that it can start sending
352+
# events. So here we start the transport send in a background task
353+
# then set the result of the request future. It's important to sequence
354+
# it just like that so that the client gets a stream that's ready
355+
# to send.
356+
transport_task = asyncio.create_task(self.transport.send(
357+
request=request_context.transport_request
358+
))
359+
request_future.set_result(request_context)
360+
transport_response = await transport_task
361+
else:
362+
# If we don't have an input stream, there's no point in creating a
363+
# task, so we just immediately await the coroutine.
364+
transport_response = await self.transport.send(
365+
request=request_context.transport_request
366+
)
300367

301368
_LOGGER.debug("Received response: %s", transport_response)
302369

0 commit comments

Comments
 (0)