Skip to content

Commit fe01293

Browse files
Handle event streams in request pipeline
1 parent f81a2ec commit fe01293

File tree

1 file changed

+117
-53
lines changed
  • packages/smithy-core/src/smithy_core/aio

1 file changed

+117
-53
lines changed

packages/smithy-core/src/smithy_core/aio/client.py

Lines changed: 117 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import asyncio
55
from asyncio import sleep, Future
66
from dataclasses import dataclass, replace
7-
from typing import Any
7+
from typing import Any, Awaitable, Callable
88

99
from .interfaces import ClientProtocol, Request, Response, ClientTransport
10+
from .interfaces.eventstream import EventReceiver
11+
from .eventstream import InputEventStream, OutputEventStream, DuplexEventStream
1012
from .. import URI
1113
from ..interfaces import TypedProperties, Endpoint
1214
from ..interfaces.retries import RetryStrategy, RetryErrorInfo, RetryErrorType
@@ -20,7 +22,7 @@
2022
from ..schemas import APIOperation
2123
from ..shapes import ShapeID
2224
from ..serializers import SerializeableShape
23-
from ..deserializers import DeserializeableShape
25+
from ..deserializers import DeserializeableShape, ShapeDeserializer
2426
from ..exceptions import SmithyRetryException, SmithyException
2527
from ..types import PropertyKey
2628

@@ -86,61 +88,123 @@ def __init__(
8688
async def __call__[I: SerializeableShape, O: DeserializeableShape](
8789
self, call: ClientCall[I, O], /
8890
) -> O:
89-
output, _ = await self._execute_request(call, None)
90-
return output
91-
92-
# async def input_stream[I: SerializeableShape, O: DeserializeableShape, E: SerializeableShape](
93-
# self, call: ClientCall[I, O], event_serializer: E, /
94-
# ) -> InputEventStream[E, O]:
95-
# request_future = Future[RequestContext[I, TRequest]]()
96-
# execute_task = asyncio.create_task(self._await_output(self._execute_request(call, request_future)))
97-
# request_context = await request_future
98-
# stream: InputEventStream[E, O] = InputEventStream(
99-
# awaitable_output=execute_task,
100-
# event_publisher=self.protocol.create_event_publisher(event_serializer, request_context)
101-
# )
102-
# return stream
103-
104-
# async def _await_output[I: SerializeableShape, O: DeserializeableShape](
105-
# self, execute_task: Awaitable[tuple[O, OutputContext[I, O, TRequest | None, TResponse | None]]]
106-
# ) -> O:
107-
# output, _ = await execute_task
108-
# return output
109-
110-
# async def output_stream[I: SerializeableShape, O: DeserializeableShape, E: DeserializeableShape](
111-
# self, call: ClientCall[I, O], event_deserializer: Callable[[ShapeDeserializer], E], /
112-
# ) -> OutputEventStream[E, O]:
113-
# output, output_context = await self._execute_request(call, None)
114-
# stream: OutputEventStream[E, O] = OutputEventStream(
115-
# output=output,
116-
# event_receiver=self.protocol.create_event_receiver(event_deserializer, output_context)
117-
# )
118-
# return stream
119-
120-
# async def duplex_stream[I: SerializeableShape, O: DeserializeableShape, IE: SerializeableShape, OE: DeserializeableShape](
121-
# self, call: ClientCall[I, O], event_serializer: IE, event_deserializer: Callable[[ShapeDeserializer], OE], /
122-
# ) -> DuplexEventStream[IE, OE, O]:
123-
# request_future = Future[RequestContext[I, TRequest]]()
124-
# execute_task = asyncio.create_task(self._execute_request(call, request_future))
125-
# request_context = await request_future
126-
# stream: DuplexEventStream[IE, OE, O] = DuplexEventStream(
127-
# awaitable_output=self._await_output_stream(execute_task, event_deserializer),
128-
# event_publisher=self.protocol.create_event_publisher(event_serializer, request_context)
129-
# )
130-
# return stream
131-
132-
# async def _await_output_stream[I: SerializeableShape, O: DeserializeableShape, E: DeserializeableShape](
133-
# self, execute_task: Awaitable[tuple[O, OutputContext[I, O, TRequest | None, TResponse | None]]],
134-
# event_deserializer: Callable[[ShapeDeserializer], E],
135-
# ) -> tuple[O, AsyncEventReceiver[E]]:
136-
# output, output_context = await execute_task
137-
# return output, self.protocol.create_event_receiver(event_deserializer, output_context)
91+
output_context = await self._execute_request(call, None)
92+
return output_context.response # type: ignore
93+
94+
async def input_stream[
95+
I: SerializeableShape,
96+
O: DeserializeableShape,
97+
E: SerializeableShape,
98+
](self, call: ClientCall[I, O], event_type: type[E], /) -> InputEventStream[E, O]:
99+
request_future = Future[RequestContext[I, TRequest]]()
100+
execute_task = asyncio.create_task(
101+
self._await_output(self._execute_request(call, request_future))
102+
)
103+
request_context = await request_future
104+
input_stream = self.protocol.create_event_publisher(
105+
operation=call.operation,
106+
request=request_context.transport_request,
107+
event_type=event_type,
108+
context=request_context.properties,
109+
)
110+
stream: InputEventStream[E, O] = InputEventStream(
111+
input_stream=input_stream, output_future=execute_task
112+
)
113+
return stream
114+
115+
async def _await_output[I: SerializeableShape, O: DeserializeableShape](
116+
self,
117+
execute_task: Awaitable[OutputContext[I, O, TRequest | None, TResponse | None]],
118+
) -> O:
119+
output_context = await execute_task
120+
return output_context.response # type: ignore
121+
122+
async def output_stream[
123+
I: SerializeableShape,
124+
O: DeserializeableShape,
125+
E: DeserializeableShape,
126+
](
127+
self,
128+
call: ClientCall[I, O],
129+
event_type: type[E],
130+
event_deserializer: Callable[[ShapeDeserializer], E],
131+
/,
132+
) -> OutputEventStream[E, O]:
133+
output_context = await self._execute_request(call, None)
134+
output_stream = self.protocol.create_event_receiver(
135+
operation=call.operation,
136+
request=output_context.transport_request, # type: ignore
137+
response=output_context.transport_response, # type: ignore
138+
event_type=event_type,
139+
event_deserializer=event_deserializer,
140+
context=output_context.properties,
141+
)
142+
stream: OutputEventStream[E, O] = OutputEventStream(
143+
output_stream=output_stream,
144+
output=output_context.response, # type: ignore
145+
)
146+
return stream
147+
148+
async def duplex_stream[
149+
I: SerializeableShape,
150+
O: DeserializeableShape,
151+
IE: SerializeableShape,
152+
OE: DeserializeableShape,
153+
](
154+
self,
155+
call: ClientCall[I, O],
156+
input_event_type: type[IE],
157+
event_serializer: IE,
158+
output_event_type: type[OE],
159+
event_deserializer: Callable[[ShapeDeserializer], OE],
160+
/,
161+
) -> DuplexEventStream[IE, OE, O]:
162+
request_future = Future[RequestContext[I, TRequest]]()
163+
execute_task = asyncio.create_task(self._execute_request(call, request_future))
164+
request_context = await request_future
165+
input_stream = self.protocol.create_event_publisher(
166+
operation=call.operation,
167+
request=request_context.transport_request,
168+
event_type=input_event_type,
169+
context=request_context.properties,
170+
)
171+
output_future = asyncio.create_task(
172+
self._await_output_stream(
173+
call=call,
174+
execute_task=execute_task,
175+
output_event_type=output_event_type,
176+
event_deserializer=event_deserializer,
177+
)
178+
)
179+
return DuplexEventStream(input_stream=input_stream, output_future=output_future)
180+
181+
async def _await_output_stream[
182+
I: SerializeableShape,
183+
O: DeserializeableShape,
184+
OE: DeserializeableShape,
185+
](
186+
self,
187+
call: ClientCall[I, O],
188+
execute_task: Awaitable[OutputContext[I, O, TRequest | None, TResponse | None]],
189+
output_event_type: type[OE],
190+
event_deserializer: Callable[[ShapeDeserializer], OE],
191+
) -> tuple[O, EventReceiver[OE]]:
192+
output_context = await execute_task
193+
output_stream = self.protocol.create_event_receiver(
194+
operation=call.operation,
195+
request=output_context.transport_request, # type: ignore
196+
response=output_context.transport_response, # type: ignore
197+
event_type=output_event_type,
198+
event_deserializer=event_deserializer,
199+
context=output_context.properties,
200+
)
201+
return output_context.response, output_stream # type: ignore
138202

139203
async def _execute_request[I: SerializeableShape, O: DeserializeableShape](
140204
self,
141205
call: ClientCall[I, O],
142206
request_future: Future[RequestContext[I, TRequest]] | None,
143-
) -> tuple[O, OutputContext[I, O, TRequest | None, TResponse | None]]:
207+
) -> OutputContext[I, O, TRequest | None, TResponse | None]:
144208
_LOGGER.debug(
145209
'Making request for operation "%s" with parameters: %s',
146210
call.operation.schema.id.name,
@@ -155,7 +219,7 @@ async def _execute_request[I: SerializeableShape, O: DeserializeableShape](
155219
raise SmithyException(e) from e
156220
raise e
157221

158-
return output_context.response, output_context
222+
return output_context
159223

160224
async def _handle_execution[I: SerializeableShape, O: DeserializeableShape](
161225
self,

0 commit comments

Comments
 (0)