Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,20 @@ def __init__(
self._source = source
self._is_client_mode = is_client_mode
self._deserializer = deserializer
self._closed = False

async def receive(self) -> E | None:
event = await Event.decode_async(self._source)
if self._closed:
return None

try:
event = await Event.decode_async(self._source)
except Exception as e:
await self.close()
if not isinstance(e, EventError):
raise IOError("Failed to read from stream.") from e
raise

if event is None:
return None

Expand All @@ -57,10 +68,18 @@ async def receive(self) -> E | None:
return result

async def close(self) -> None:
if self._closed:
return
self._closed = True

if (close := getattr(self._source, "close", None)) is not None:
if asyncio.iscoroutine(result := close()):
await result

@property
def closed(self) -> bool:
return self._closed


class EventDeserializer(SpecificShapeDeserializer):
def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@ def __init__(
self._serializer = EventSerializer(
payload_codec=payload_codec, is_client_mode=is_client_mode
)
self._closed = False

async def send(self, event: E) -> None:
if self._closed:
raise IOError("Attempted to write to closed stream.")
event.serialize(self._serializer)
result = self._serializer.get_result()
if result is None:
Expand All @@ -60,13 +63,27 @@ async def send(self, event: E) -> None:
)
if self._signer is not None:
result = self._signer(result)
await self._writer.write(result.encode())

encoded_result = result.encode()
try:
await self._writer.write(encoded_result)
except Exception as e:
await self.close()
raise IOError("Failed to write to stream.") from e

async def close(self) -> None:
if self._closed:
return
self._closed = True

if (close := getattr(self._writer, "close", None)) is not None:
if asyncio.iscoroutine(result := close()):
await result

@property
def closed(self) -> bool:
return self._closed


class EventSerializer(SpecificShapeSerializer):
def __init__(
Expand Down
52 changes: 30 additions & 22 deletions packages/aws-event-stream/src/aws_event_stream/aio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,31 +75,35 @@ def __init__(
self.response: R | None = None

async def await_output(self) -> tuple[R, AsyncEventReceiver[O]]:
async_reader = AsyncBytesReader((await self._awaitable_response).body)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is formatted a bit unfortunately. I just wrapped it in a try/except

if self.output_stream is None:
self.output_stream = _AWSEventReceiver[O](
payload_codec=self._payload_codec,
source=async_reader,
deserializer=self._deserializer,
is_client_mode=self._is_client_mode,
)

if self.response is None:
if self._deserializeable_response is None:
initial_response = await self._awaitable_output
else:
initial_response_stream = _AWSEventReceiver(
try:
async_reader = AsyncBytesReader((await self._awaitable_response).body)
if self.output_stream is None:
self.output_stream = _AWSEventReceiver[O](
payload_codec=self._payload_codec,
source=async_reader,
deserializer=self._deserializeable_response.deserialize,
deserializer=self._deserializer,
is_client_mode=self._is_client_mode,
)
initial_response = await initial_response_stream.receive()
if initial_response is None:
raise MissingInitialResponse()
self.response = initial_response
else:
initial_response = self.response

if self.response is None:
if self._deserializeable_response is None:
initial_response = await self._awaitable_output
else:
initial_response_stream = _AWSEventReceiver(
payload_codec=self._payload_codec,
source=async_reader,
deserializer=self._deserializeable_response.deserialize,
is_client_mode=self._is_client_mode,
)
initial_response = await initial_response_stream.receive()
if initial_response is None:
raise MissingInitialResponse()
self.response = initial_response
else:
initial_response = self.response
except Exception:
await self.input_stream.close()
raise

return initial_response, self.output_stream

Expand Down Expand Up @@ -137,7 +141,11 @@ def __init__(

async def await_output(self) -> R:
if self.response is None:
self.response = await self._awaitable_response
try:
self.response = await self._awaitable_response
except Exception:
await self.input_stream.close()
raise
return self.response


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,43 @@ def test_deserialize_unmodeled_error():

with pytest.raises(UnmodeledEventError, match="InternalError"):
EventStreamOperationInputOutput.deserialize(deserializer)


async def test_receiver_closes_source() -> None:
source = AsyncBytesReader(b"")
deserializer = EventStreamDeserializer()
receiver = AWSAsyncEventReceiver[Any](
payload_codec=JSONCodec(), source=source, deserializer=deserializer.deserialize
)
assert not receiver.closed
assert not source.closed
await receiver.close()
assert receiver.closed
assert source.closed


async def test_read_closed_receiver() -> None:
source = AsyncBytesReader(b"")
deserializer = EventStreamDeserializer()
receiver = AWSAsyncEventReceiver[Any](
payload_codec=JSONCodec(), source=source, deserializer=deserializer.deserialize
)

await receiver.close()
assert receiver.closed
assert await receiver.receive() is None


async def test_read_closed_receiver_source() -> None:
source = AsyncBytesReader(b"")
deserializer = EventStreamDeserializer()
receiver = AWSAsyncEventReceiver[Any](
payload_codec=JSONCodec(), source=source, deserializer=deserializer.deserialize
)

await source.close()
assert source.closed
assert not receiver.closed
with pytest.raises(IOError):
await receiver.receive()
assert receiver.closed
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import pytest
from smithy_core.serializers import SerializeableShape
from smithy_core.aio.types import AsyncBytesProvider
from smithy_json import JSONCodec

from aws_event_stream._private.serializers import EventSerializer
from aws_event_stream._private.serializers import (
EventSerializer,
AWSAsyncEventPublisher,
)
from aws_event_stream.events import EventMessage

from . import EVENT_STREAM_SERDE_CASES, INITIAL_REQUEST_CASE, INITIAL_RESPONSE_CASE
Expand Down Expand Up @@ -36,3 +42,41 @@ def test_serialize_initial_request():

def test_serialize_initial_response():
test_event_serializer_server_mode(*INITIAL_RESPONSE_CASE)


async def test_publisher_closes_reader():
writer = AsyncBytesProvider()
publisher: AWSAsyncEventPublisher[Any] = AWSAsyncEventPublisher(
payload_codec=JSONCodec(), async_writer=writer
)

assert not publisher.closed
assert not writer.closed
await publisher.close()
assert publisher.closed
assert writer.closed


async def test_send_after_close():
writer = AsyncBytesProvider()
publisher: AWSAsyncEventPublisher[Any] = AWSAsyncEventPublisher(
payload_codec=JSONCodec(), async_writer=writer
)

await publisher.close()
assert publisher.closed
with pytest.raises(IOError):
await publisher.send(EVENT_STREAM_SERDE_CASES[0][0])


async def test_send_to_closed_writer():
writer = AsyncBytesProvider()
publisher: AWSAsyncEventPublisher[Any] = AWSAsyncEventPublisher(
payload_codec=JSONCodec(), async_writer=writer
)

await writer.close()
with pytest.raises(IOError):
await publisher.send(EVENT_STREAM_SERDE_CASES[0][0])

assert publisher.closed
15 changes: 5 additions & 10 deletions packages/smithy-core/src/smithy_core/aio/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..exceptions import SmithyException
from ..interfaces import BytesReader
from .interfaces import AsyncByteStream, StreamingBlob
from .utils import close

# The default chunk size for iterating streams.
_DEFAULT_CHUNK_SIZE = 1024
Expand Down Expand Up @@ -114,12 +115,9 @@ def closed(self) -> bool:

async def close(self) -> None:
"""Closes the stream, as well as the underlying stream where possible."""
if (close := getattr(self._data, "close", None)) is not None:
if asyncio.iscoroutine(result := close()):
await result

self._data = None
self._closed = True
await close(self._data)
self._data = None


class SeekableAsyncBytesReader:
Expand Down Expand Up @@ -250,12 +248,9 @@ def closed(self) -> bool:

async def close(self) -> None:
"""Closes the stream, as well as the underlying stream where possible."""
if (close := getattr(self._data_source, "close", None)) is not None:
if asyncio.iscoroutine(result := close()):
await result

self._data_source = None
self._buffer.close()
await close(self._data_source)
self._data_source = None


class _AsyncByteStreamIterator:
Expand Down
15 changes: 10 additions & 5 deletions packages/smithy-core/src/smithy_core/aio/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from asyncio import sleep
from asyncio import sleep, iscoroutine
from collections.abc import AsyncIterable, Iterable
from typing import TypeVar
from typing import Any

from ..exceptions import AsyncBodyException
from ..interfaces import BytesReader
from ..interfaces import StreamingBlob as SyncStreamingBlob
from .interfaces import AsyncByteStream, StreamingBlob

_ListEl = TypeVar("_ListEl")


async def async_list(lst: Iterable[_ListEl]) -> AsyncIterable[_ListEl]:
async def async_list[E](lst: Iterable[E]) -> AsyncIterable[E]:
"""Turn an Iterable into an AsyncIterable."""
for x in lst:
await sleep(0)
Expand Down Expand Up @@ -53,3 +51,10 @@ def read_streaming_blob(body: StreamingBlob) -> bytes:
raise AsyncBodyException(
f"Expected type {SyncStreamingBlob}, but was {type(body)}"
)


async def close(stream: Any) -> None:
"""Close a stream, awaiting it if it's async."""
if (close := getattr(stream, "close", None)) is not None:
if iscoroutine(result := close()):
await result
24 changes: 24 additions & 0 deletions packages/smithy-core/tests/unit/aio/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from io import BytesIO

from smithy_core.aio.types import AsyncBytesProvider
from smithy_core.aio.utils import close


async def test_close_sync_closeable() -> None:
stream = BytesIO()
assert not stream.closed
await close(stream)
assert stream.closed


async def test_close_async_closeable() -> None:
stream = AsyncBytesProvider()
assert not stream.closed
await close(stream)
assert stream.closed


async def test_close_non_closeable() -> None:
await close(b"foo")
Loading