diff --git a/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py b/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py index 1f4ca8a0b..f04b7e50e 100644 --- a/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py +++ b/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py @@ -3,11 +3,16 @@ import datetime from collections.abc import Callable +from smithy_core.aio.interfaces import AsyncByteStream, AsyncCloseable from smithy_core.codecs import Codec -from smithy_core.deserializers import ShapeDeserializer, SpecificShapeDeserializer -from smithy_core.interfaces import BytesReader +from smithy_core.deserializers import ( + DeserializeableShape, + ShapeDeserializer, + SpecificShapeDeserializer, +) from smithy_core.schemas import Schema from smithy_core.utils import expect_type +from smithy_event_stream.aio.interfaces import AsyncEventReceiver from ..events import HEADERS_DICT, Event from ..exceptions import EventError, UnmodeledEventError @@ -17,11 +22,38 @@ INITIAL_MESSAGE_TYPES = (INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE) -class EventDeserializer(SpecificShapeDeserializer): +class AWSAsyncEventReceiver[E: DeserializeableShape](AsyncEventReceiver[E]): def __init__( - self, source: BytesReader, payload_codec: Codec, is_client_mode: bool = True + self, + payload_codec: Codec, + source: AsyncByteStream, + deserializer: Callable[[ShapeDeserializer], E], + is_client_mode: bool = True, ) -> None: + self._payload_codec = payload_codec self._source = source + self._is_client_mode = is_client_mode + self._deserializer = deserializer + + async def receive(self) -> E | None: + event = await Event.decode_async(self._source) + deserializer = EventDeserializer( + event=event, + payload_codec=self._payload_codec, + is_client_mode=self._is_client_mode, + ) + return self._deserializer(deserializer) + + async def close(self) -> None: + if isinstance(self._source, AsyncCloseable): + await self._source.close() + + +class EventDeserializer(SpecificShapeDeserializer): + def __init__( + self, event: Event, payload_codec: Codec, is_client_mode: bool = True + ) -> None: + self._event = event self._payload_codec = payload_codec self._is_client_mode = is_client_mode @@ -30,13 +62,12 @@ def read_struct( schema: Schema, consumer: Callable[[Schema, ShapeDeserializer], None], ) -> None: - event = Event.decode(self._source) - headers = event.message.headers + headers = self._event.message.headers payload_deserializer = None - if event.message.payload: + if self._event.message.payload: payload_deserializer = self._payload_codec.create_deserializer( - event.message.payload + self._event.message.payload ) message_deserializer = EventMessageDeserializer(headers, payload_deserializer) @@ -61,7 +92,7 @@ def read_struct( expect_type(str, headers[":error-message"]), ) case _: - raise EventError(f"Unknown event structure: {event}") + raise EventError(f"Unknown event structure: {self._event}") class EventMessageDeserializer(SpecificShapeDeserializer): diff --git a/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py b/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py index 276488c47..3baf1c0e7 100644 --- a/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py +++ b/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py @@ -1,20 +1,24 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import datetime -from collections.abc import Iterator +from collections.abc import Callable, Iterator from contextlib import contextmanager from io import BytesIO from typing import Never +from smithy_core.aio.interfaces import AsyncCloseable, AsyncWriter from smithy_core.codecs import Codec +from smithy_core.exceptions import ExpectationNotMetException from smithy_core.schemas import Schema from smithy_core.serializers import ( InterceptingSerializer, + SerializeableShape, ShapeSerializer, SpecificShapeSerializer, ) from smithy_core.shapes import ShapeType from smithy_core.utils import expect_type +from smithy_event_stream.aio.interfaces import AsyncEventPublisher from ..events import EventHeaderEncoder, EventMessage from ..exceptions import InvalidHeaderValue @@ -30,6 +34,40 @@ _DEFAULT_BLOB_CONTENT_TYPE = "application/octet-stream" +type Signer = Callable[[EventMessage], EventMessage] +"""A function that takes an event message and signs it, and returns it signed.""" + + +class AWSAsyncEventPublisher[E: SerializeableShape](AsyncEventPublisher[E]): + def __init__( + self, + payload_codec: Codec, + async_writer: AsyncWriter, + signer: Signer | None = None, + is_client_mode: bool = True, + ): + self._writer = async_writer + self._signer = signer + self._serializer = EventSerializer( + payload_codec=payload_codec, is_client_mode=is_client_mode + ) + + async def send(self, event: E) -> None: + event.serialize(self._serializer) + result = self._serializer.get_result() + if result is None: + raise ExpectationNotMetException( + "Expected an event message to be serialized, but was None." + ) + if self._signer is not None: + result = self._signer(result) + await self._writer.write(result.encode()) + + async def close(self) -> None: + if isinstance(self._writer, AsyncCloseable): + await self._writer.close() + + class EventSerializer(SpecificShapeSerializer): def __init__( self, diff --git a/python-packages/aws-event-stream/aws_event_stream/aio/__init__.py b/python-packages/aws-event-stream/aws_event_stream/aio/__init__.py new file mode 100644 index 000000000..b019e004e --- /dev/null +++ b/python-packages/aws-event-stream/aws_event_stream/aio/__init__.py @@ -0,0 +1,247 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import asyncio +from collections.abc import Callable +from typing import Self + +from smithy_core.aio.interfaces import AsyncByteStream, AsyncWriter +from smithy_core.codecs import Codec +from smithy_core.deserializers import DeserializeableShape, ShapeDeserializer +from smithy_core.serializers import SerializeableShape +from smithy_event_stream.aio.interfaces import ( + AsyncEventReceiver, + DuplexEventStream, + InputEventStream, + OutputEventStream, +) + +from .._private.deserializers import AWSAsyncEventReceiver as _AWSEventReceiver +from .._private.serializers import AWSAsyncEventPublisher as _AWSEventPublisher +from .._private.serializers import Signer +from ..exceptions import MissingInitialResponse + + +class AWSDuplexEventStream[ + I: SerializeableShape, O: DeserializeableShape, R: DeserializeableShape +](DuplexEventStream[I, O, R]): + """A duplex event stream using the application/vnd.amazon.eventstream format.""" + + def __init__( + self, + payload_codec: Codec, + async_writer: AsyncWriter, + deserializer: Callable[[ShapeDeserializer], O], + async_reader: AsyncByteStream | None = None, + initial_response: R | None = None, + deserializeable_response: type[R] | None = None, + signer: Signer | None = None, + is_client_mode: bool = True, + ) -> None: + """Construct an AWSDuplexEventStream. + + :param payload_codec: The codec to encode the event payload with. + :param async_writer: The writer to write event bytes to. + :param deserializer: A callable to deserialize events with. This should be the + union's deserialize method. + :param async_reader: The reader to read event bytes from, if available. If not + immediately available, output will be blocked on it becoming available. + :param initial_response: The deserialized operation response, if available. If + not immediately available, output will be blocked on it becoming available. + :param deserializeable_response: The deserializeable response class. Setting + this indicates that the initial response is sent over the event stream. The + deserialize method of this class will be used to deserialize it upon + calling ``await_output``. + :param signer: An optional callable to sign events with prior to them being + encoded. + :param is_client_mode: Whether the stream is being constructed for a client or + server implementation. + """ + self.input_stream = _AWSEventPublisher( + payload_codec=payload_codec, + async_writer=async_writer, + signer=signer, + is_client_mode=is_client_mode, + ) + + self._deserializer = deserializer + self._payload_codec = payload_codec + self._is_client_mode = is_client_mode + + # Create a future to allow awaiting the reader + loop = asyncio.get_event_loop() + self._reader_future: asyncio.Future[AsyncByteStream] = loop.create_future() + if async_reader is not None: + self._reader_future.set_result(async_reader) + + # Create a future to allow awaiting the initial response + self._response = initial_response + self._deserializerable_response = deserializeable_response + self._response_future: asyncio.Future[R] = loop.create_future() + + @property + def response(self) -> R | None: + return self._response + + @response.setter + def response(self, value: R) -> None: + self._response_future.set_result(value) + self._response = value + + def set_reader(self, value: AsyncByteStream) -> None: + """Sets the object to read events from. + + :param value: An async readable object to read event bytes from. + """ + self._reader_future.set_result(value) + + async def await_output(self) -> tuple[R, AsyncEventReceiver[O]]: + async_reader = await self._reader_future + if self.output_stream is None: + self.output_stream = _AWSEventReceiver[O]( + payload_codec=self._payload_codec, + source=async_reader, + deserializer=self._deserializer, + is_client_mode=self._is_client_mode, + ) + + if self.response is None: + if self._deserializerable_response is None: + initial_response = await self._response_future + else: + initial_response_stream = _AWSEventReceiver( + payload_codec=self._payload_codec, + source=async_reader, + deserializer=self._deserializerable_response.deserialize, + is_client_mode=self._is_client_mode, + ) + initial_response = await initial_response_stream.receive() + if initial_response is None: + raise MissingInitialResponse() + self.response = initial_response + else: + initial_response = self.response + + return initial_response, self.output_stream + + +class AWSInputEventStream[I: SerializeableShape, R](InputEventStream[I, R]): + """An input event stream using the application/vnd.amazon.eventstream format.""" + + def __init__( + self, + payload_codec: Codec, + async_writer: AsyncWriter, + initial_response: R | None = None, + signer: Signer | None = None, + is_client_mode: bool = True, + ) -> None: + """Construct an AWSInputEventStream. + + :param payload_codec: The codec to encode the event payload with. + :param async_writer: The writer to write event bytes to. + :param initial_response: The deserialized operation response, if available. + :param signer: An optional callable to sign events with prior to them being + encoded. + :param is_client_mode: Whether the stream is being constructed for a client or + server implementation. + """ + self._response = initial_response + + # Create a future to allow awaiting the initial response. + loop = asyncio.get_event_loop() + self._response_future: asyncio.Future[R] = loop.create_future() + if initial_response is not None: + self._response_future.set_result(initial_response) + + self.input_stream = _AWSEventPublisher( + payload_codec=payload_codec, + async_writer=async_writer, + signer=signer, + is_client_mode=is_client_mode, + ) + + @property + def response(self) -> R | None: + return self._response + + @response.setter + def response(self, value: R) -> None: + self._response_future.set_result(value) + self._response = value + + async def await_output(self) -> R: + return await self._response_future + + +class AWSOutputEventStream[O: DeserializeableShape, R: DeserializeableShape]( + OutputEventStream[O, R] +): + """An output event stream using the application/vnd.amazon.eventstream format.""" + + def __init__( + self, + payload_codec: Codec, + initial_response: R, + async_reader: AsyncByteStream, + deserializer: Callable[[ShapeDeserializer], O], + is_client_mode: bool = True, + ) -> None: + """Construct an AWSOutputEventStream. + + :param payload_codec: The codec to decode event payloads with. + :param initial_response: The deserialized operation response. If this is not + available immediately, use ``AWSOutputEventStream.create``. + :param async_reader: An async reader to read event bytes from. + :param deserializer: A callable to deserialize events with. This should be the + union's deserialize method. + :param is_client_mode: Whether the stream is being constructed for a client or + server implementation. + """ + self.response = initial_response + self.output_stream = _AWSEventReceiver[O]( + payload_codec=payload_codec, + source=async_reader, + deserializer=deserializer, + is_client_mode=is_client_mode, + ) + + @classmethod + async def create( + cls, + payload_codec: Codec, + deserializeable_response: type[R], + async_reader: AsyncByteStream, + deserializer: Callable[[ShapeDeserializer], O], + is_client_mode: bool = True, + ) -> Self: + """Construct an AWSOutputEventStream and decode the initial response. + + :param payload_codec: The codec to decode event payloads with. + :param deserializeable_response: The deserializeable response class. The + deserialize method of this class will be used to deserialize the + initial response from the stream.. + :param initial_response: The deserialized operation response. If this is not + available immediately, use ``AWSOutputEventStream.create``. + :param async_reader: An async reader to read event bytes from. + :param deserializer: A callable to deserialize events with. This should be the + union's deserialize method. + :param is_client_mode: Whether the stream is being constructed for a client or + server implementation. + """ + initial_response_stream = _AWSEventReceiver( + payload_codec=payload_codec, + source=async_reader, + deserializer=deserializeable_response.deserialize, + is_client_mode=is_client_mode, + ) + initial_response = await initial_response_stream.receive() + if initial_response is None: + raise MissingInitialResponse() + + return cls( + payload_codec=payload_codec, + initial_response=initial_response, + async_reader=async_reader, + deserializer=deserializer, + is_client_mode=is_client_mode, + ) diff --git a/python-packages/aws-event-stream/aws_event_stream/events.py b/python-packages/aws-event-stream/aws_event_stream/events.py index 50da440be..ec77473f0 100644 --- a/python-packages/aws-event-stream/aws_event_stream/events.py +++ b/python-packages/aws-event-stream/aws_event_stream/events.py @@ -17,6 +17,7 @@ from struct import pack, unpack from typing import Literal, Self +from smithy_core.aio.interfaces import AsyncByteStream from smithy_core.interfaces import BytesReader from smithy_core.types import TimestampFormat @@ -314,6 +315,36 @@ def decode(cls, source: BytesReader) -> Self: ) return cls(prelude, message, crc) + @classmethod + async def decode_async(cls, source: AsyncByteStream) -> Self: + """Decode an event from an async byte stream. + + :param source: An object to read event bytes from. It must have a `read` method + that accepts a number of bytes to read. + + :returns: An Event representing the next event on the source. + """ + + prelude_bytes = await source.read(8) + prelude_crc_bytes = await source.read(4) + prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0] + + total_length, headers_length = unpack("!II", prelude_bytes) + _validate_checksum(prelude_bytes, prelude_crc) + prelude = EventPrelude( + total_length=total_length, headers_length=headers_length, crc=prelude_crc + ) + + message_bytes = await source.read(total_length - _MESSAGE_METADATA_SIZE) + crc: int = _DecodeUtils.unpack_uint32(await source.read(4))[0] + _validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc) + + message = EventMessage( + headers_bytes=message_bytes[: prelude.headers_length], + payload=message_bytes[prelude.headers_length :], + ) + return cls(prelude, message, crc) + class _EventEncoder: """A utility class that encodes message bytes into binary events.""" diff --git a/python-packages/aws-event-stream/aws_event_stream/exceptions.py b/python-packages/aws-event-stream/aws_event_stream/exceptions.py index 8c5dc0837..ec705c611 100644 --- a/python-packages/aws-event-stream/aws_event_stream/exceptions.py +++ b/python-packages/aws-event-stream/aws_event_stream/exceptions.py @@ -91,3 +91,8 @@ def __init__(self, size: str, value: int): f"be 32-bit." ) super().__init__(message) + + +class MissingInitialResponse(EventError): + def __init__(self) -> None: + super().__init__("Expected an initial response, but none was found.") diff --git a/python-packages/aws-event-stream/pyproject.toml b/python-packages/aws-event-stream/pyproject.toml index 9e31443db..37a651614 100644 --- a/python-packages/aws-event-stream/pyproject.toml +++ b/python-packages/aws-event-stream/pyproject.toml @@ -25,6 +25,10 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Topic :: Software Development :: Libraries" ] +dependencies=[ + "smithy_core==0.0.1", + "smithy_event_stream==0.0.1", +] [project.urls] source = "https://github.com/awslabs/smithy-python/tree/develop/python-packages/aws-event-stream" diff --git a/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py b/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py index 9dfa84dfe..301533905 100644 --- a/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py +++ b/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py @@ -7,7 +7,7 @@ from smithy_json import JSONCodec from aws_event_stream._private.deserializers import EventDeserializer -from aws_event_stream.events import EventMessage +from aws_event_stream.events import Event, EventMessage from aws_event_stream.exceptions import UnmodeledEventError from . import ( @@ -21,24 +21,24 @@ @pytest.mark.parametrize("expected,given", EVENT_STREAM_SERDE_CASES) def test_event_deserializer(expected: DeserializeableShape, given: EventMessage): - source = BytesIO(given.encode()) - deserializer = EventDeserializer(source=source, payload_codec=JSONCodec()) + source = Event.decode(BytesIO(given.encode())) + deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) result = EventStreamDeserializer().deserialize(deserializer) assert result == expected def test_deserialize_initial_request(): expected, given = INITIAL_REQUEST_CASE - source = BytesIO(given.encode()) - deserializer = EventDeserializer(source=source, payload_codec=JSONCodec()) + source = Event.decode(BytesIO(given.encode())) + deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) result = EventStreamOperationInputOutput.deserialize(deserializer) assert result == expected def test_deserialize_initial_response(): expected, given = INITIAL_RESPONSE_CASE - source = BytesIO(given.encode()) - deserializer = EventDeserializer(source=source, payload_codec=JSONCodec()) + source = Event.decode(BytesIO(given.encode())) + deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) result = EventStreamOperationInputOutput.deserialize(deserializer) assert result == expected @@ -51,8 +51,8 @@ def test_deserialize_unmodeled_error(): ":error-message": "An internal server error occurred.", } ) - source = BytesIO(message.encode()) - deserializer = EventDeserializer(source=source, payload_codec=JSONCodec()) + source = Event.decode(BytesIO(message.encode())) + deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) with pytest.raises(UnmodeledEventError, match="InternalError"): EventStreamOperationInputOutput.deserialize(deserializer) diff --git a/python-packages/aws-event-stream/tests/unit/test_events.py b/python-packages/aws-event-stream/tests/unit/test_events.py index 0bbdfec95..d6c534bb8 100644 --- a/python-packages/aws-event-stream/tests/unit/test_events.py +++ b/python-packages/aws-event-stream/tests/unit/test_events.py @@ -6,6 +6,7 @@ from io import BytesIO import pytest +from smithy_core.aio.types import AsyncBytesReader from aws_event_stream.events import ( MAX_HEADER_VALUE_BYTE_LENGTH, @@ -506,6 +507,11 @@ def test_decode(encoded: bytes, expected: Event): assert Event.decode(BytesIO(encoded)) == expected +@pytest.mark.parametrize("encoded,expected", POSITIVE_CASES) +async def test_decode_async(encoded: bytes, expected: Event): + assert await Event.decode_async(AsyncBytesReader(encoded)) == expected + + @pytest.mark.parametrize("expected,event", POSITIVE_CASES) def test_encode(expected: bytes, event: Event): assert event.message.encode() == expected @@ -530,6 +536,25 @@ def test_negative_cases(encoded: bytes, expected: type[Exception]): Event.decode(BytesIO(encoded)) +@pytest.mark.parametrize( + "encoded,expected", + NEGATIVE_CASES, + ids=[ + "corrupted-length", + "corrupted-payload", + "corrupted-headers", + "corrupted-headers-length", + "duplicate-header", + "invalid-headers-length", + "invalid-header-value-length", + "invalid-payload-length", + ], +) +async def test_negative_cases_async(encoded: bytes, expected: type[Exception]): + with pytest.raises(expected): + await Event.decode_async(AsyncBytesReader(encoded)) + + def test_event_prelude_rejects_long_headers(): headers_length = MAX_HEADERS_LENGTH + 1 total_length = headers_length + 16 diff --git a/python-packages/smithy-core/smithy_core/aio/interfaces/__init__.py b/python-packages/smithy-core/smithy_core/aio/interfaces/__init__.py index bdf71d54a..79fdb60ca 100644 --- a/python-packages/smithy-core/smithy_core/aio/interfaces/__init__.py +++ b/python-packages/smithy-core/smithy_core/aio/interfaces/__init__.py @@ -14,6 +14,20 @@ class AsyncByteStream(Protocol): async def read(self, size: int = -1) -> bytes: ... +@runtime_checkable +class AsyncWriter(Protocol): + """An object with an async write method.""" + + async def write(self, data: bytes) -> None: ... + + +@runtime_checkable +class AsyncCloseable(Protocol): + """An object that can asynchronously close.""" + + async def close(self): ... + + # A union of all acceptable streaming blob types. Deserialized payloads will # always return a ByteStream, or AsyncByteStream if async is enabled. type StreamingBlob = SyncStreamingBlob | AsyncByteStream | AsyncIterable[bytes] diff --git a/python-packages/smithy-event-stream/smithy_event_stream/aio/__init__.py b/python-packages/smithy-event-stream/smithy_event_stream/aio/__init__.py new file mode 100644 index 000000000..04f8b7b76 --- /dev/null +++ b/python-packages/smithy-event-stream/smithy_event_stream/aio/__init__.py @@ -0,0 +1,2 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py b/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py index 5708f143f..a38468d0b 100644 --- a/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py +++ b/python-packages/smithy-event-stream/smithy_event_stream/aio/interfaces.py @@ -6,7 +6,7 @@ from smithy_core.serializers import SerializeableShape -class InputEventStream[E: SerializeableShape](Protocol): +class AsyncEventPublisher[E: SerializeableShape](Protocol): """Asynchronously sends events to a service. This may be used as a context manager to ensure the stream is closed before exiting. @@ -30,7 +30,7 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): await self.close() -class OutputEventStream[E: DeserializeableShape](Protocol): +class AsyncEventReceiver[E: DeserializeableShape](Protocol): """Asynchronously receives events from a service. Events may be received via the ``receive`` method or by using this class as @@ -69,10 +69,8 @@ async def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): await self.close() -class EventStream[I: InputEventStream[Any] | None, O: OutputEventStream[Any] | None, R]( - Protocol -): - """A unidirectional or bidirectional event stream. +class DuplexEventStream[I: SerializeableShape, O: DeserializeableShape, R](Protocol): + """An event stream that both sends and receives messages. To ensure that streams are closed upon exiting, this class may be used as an async context manager. @@ -104,30 +102,46 @@ async def handle_output(stream: EventStream) -> None: return """ - input_stream: I - """An event stream that sends events to the service. + input_stream: AsyncEventPublisher[I] + """An event stream that sends events to the service.""" - This value will be None if the operation has no input stream. - """ + # Exposing response and output_stream via @property allows implementations that + # don't have it immediately available to do things like put a future in + # await_output or otherwise reasonably implement that method while still allowing + # them to inherit directly from the protocol class. + _output_stream: AsyncEventReceiver[O] | None = None + _response: R | None = None - output_stream: O | None = None - """An event stream that receives events from the service. + @property + def output_stream(self) -> AsyncEventReceiver[O] | None: + """An event stream that receives events from the service. - This value may be None until ``await_output`` has been called. + This value may be None until ``await_output`` has been called. - This value will also be None if the operation has no output stream. - """ + This value will also be None if the operation has no output stream. + """ + return self._output_stream - response: R | None = None - """The initial response from the service. + @output_stream.setter + def output_stream(self, value: AsyncEventReceiver[O]) -> None: + self._output_stream = value - This value may be None until ``await_output`` has been called. + @property + def response(self) -> R | None: + """The initial response from the service. - This may include context necessary to interpret output events or prepare - input events. It will always be available before any events. - """ + This value may be None until ``await_output`` has been called. + + This may include context necessary to interpret output events or prepare + input events. It will always be available before any events. + """ + return self._response - async def await_output(self) -> tuple[R, O]: + @response.setter + def response(self, value: R) -> None: + self._response = value + + async def await_output(self) -> tuple[R, AsyncEventReceiver[O]]: """Await the operation's output. The EventStream will be returned as soon as the input stream is ready to @@ -146,17 +160,6 @@ async def await_output(self) -> tuple[R, O]: :returns: A tuple containing the initial response and output stream. If the operation has no output stream, the second value will be None. """ - if self.response is not None: - self.response, self.output_stream = await self._await_output() - - return self._response, self._output_stream # type: ignore - - async def _await_output(self) -> tuple[R, O]: - """Await the operation's output without caching. - - This method is meant to be used with the default implementation of await_output. - It should return the output directly without caching. - """ ... async def close(self) -> None: @@ -167,8 +170,123 @@ async def close(self) -> None: if self.output_stream is None: _, self.output_stream = await self.await_output() - if self.output_stream is not None: - await self.output_stream.close() + await self.input_stream.close() + await self.output_stream.close() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): + await self.close() + + +class InputEventStream[I: SerializeableShape, R](Protocol): + """An event stream that streams messages to the service. + + To ensure that streams are closed upon exiting, this class may be used as an async + context manager. + + .. code-block:: python + + async def main(): + client = ChatClient() + input = PublishMessagesInput(chat_room="aws-python-sdk", username="hunter7") + + async with client.publish_messages(input=input) as stream: + stream.input_stream.send(MessageStreamMessage("High severity ticket alert!")) + await stream.await_output() + """ + + input_stream: AsyncEventPublisher[I] + """An event stream that sends events to the service.""" + + # Exposing response via @property allows implementations that don't have it + # immediately available to do things like put a future in await_output or + # otherwise reasonably implement that method while still allowing them to + # inherit directly from the protocol class. + _response: R | None = None + + @property + def response(self) -> R | None: + """The initial response from the service. + + This value may be None until ``await_output`` has been called. + + This may include context necessary to interpret output events or prepare + input events. It will always be available before any events. + """ + return self._response + + @response.setter + def response(self, value: R) -> None: + self._response = value + + async def await_output(self) -> R: + """Await the operation's output. + + The InputEventStream will be returned as soon as the input stream is ready to + receive events, which may be before the initial response has been received. + + Awaiting this method will wait until the initial response was received. The + operation response will be returned by this operation and also cached in + ``response``. + + The default implementation of this method performs the caching behavior, + delegating to the abstract ``_await_output`` method to actually retrieve the + operation response. + + :returns: The operation's response. + """ + ... + + async def close(self) -> None: + """Closes the event stream.""" + await self.input_stream.close() + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any): + await self.close() + + +class OutputEventStream[O: DeserializeableShape, R](Protocol): + """An event stream that streams messages from the service. + + To ensure that streams are closed upon exiting, this class may be used as an async + context manager. + + .. code-block:: python + + async def main(): + client = ChatClient() + input = ReceiveMessagesInput(chat_room="aws-python-sdk") + + async with client.receive_messages(input=input) as stream: + async for event in stream.output_stream: + match event: + case MessageStreamMessage(): + print(event.value) + case _: + return + """ + + output_stream: AsyncEventReceiver[O] + """An event stream that receives events from the service. + + This value will also be None if the operation has no output stream. + """ + + response: R + """The initial response from the service. + + This may include context necessary to interpret output events or prepare input + events. It will always be available before any events. + """ + + async def close(self) -> None: + """Closes the event stream.""" + await self.output_stream.close() async def __aenter__(self) -> Self: return self