diff --git a/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py b/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py index 5bfff8ed4..a1fa3619d 100644 --- a/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py +++ b/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py @@ -51,7 +51,10 @@ async def receive(self) -> E | None: payload_codec=self._payload_codec, is_client_mode=self._is_client_mode, ) - return self._deserializer(deserializer) + result = self._deserializer(deserializer) + if isinstance(getattr(result, "value"), Exception): + raise result.value # type: ignore + return result async def close(self) -> None: if (close := getattr(self._source, "close", None)) is not None: diff --git a/python-packages/aws-event-stream/tests/unit/_private/__init__.py b/python-packages/aws-event-stream/tests/unit/_private/__init__.py index 54391d038..150f2344b 100644 --- a/python-packages/aws-event-stream/tests/unit/_private/__init__.py +++ b/python-packages/aws-event-stream/tests/unit/_private/__init__.py @@ -341,7 +341,7 @@ def serialize_members(self, serializer: ShapeSerializer): @dataclass -class ErrorEvent: +class ErrorEvent(Exception): code: ClassVar[str] = "NoSuchResource" fault: ClassVar[Literal["client", "server"]] = "client" diff --git a/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py b/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py index 3e4af5def..a153864e9 100644 --- a/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py +++ b/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py @@ -1,12 +1,17 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from io import BytesIO +from typing import Any import pytest +from smithy_core.aio.types import AsyncBytesReader from smithy_core.deserializers import DeserializeableShape from smithy_json import JSONCodec -from aws_event_stream._private.deserializers import EventDeserializer +from aws_event_stream._private.deserializers import ( + AWSAsyncEventReceiver, + EventDeserializer, +) from aws_event_stream.events import Event, EventMessage from aws_event_stream.exceptions import UnmodeledEventError @@ -14,11 +19,35 @@ EVENT_STREAM_SERDE_CASES, INITIAL_REQUEST_CASE, INITIAL_RESPONSE_CASE, + ErrorEvent, EventStreamDeserializer, + EventStreamErrorEvent, EventStreamOperationInputOutput, ) +@pytest.mark.parametrize("expected,given", EVENT_STREAM_SERDE_CASES) +async def test_event_receiver(expected: DeserializeableShape, given: EventMessage): + source = AsyncBytesReader(given.encode()) + deserializer = EventStreamDeserializer() + receiver = AWSAsyncEventReceiver[Any]( + payload_codec=JSONCodec(), source=source, deserializer=deserializer.deserialize + ) + + result: Any = None + + try: + result = await receiver.receive() + except ErrorEvent as e: + if isinstance(expected, EventStreamErrorEvent): + expected = expected.value + else: + raise + result = e + + assert result == expected + + @pytest.mark.parametrize("expected,given", EVENT_STREAM_SERDE_CASES) def test_event_deserializer(expected: DeserializeableShape, given: EventMessage): source = Event.decode(BytesIO(given.encode()))