Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,20 @@ def __init__(
self._source = source
self._is_client_mode = is_client_mode
self._deserializer = deserializer
self._closed = False

async def receive(self) -> E | None:
event = await Event.decode_async(self._source)
if self._closed:
return None

try:
event = await Event.decode_async(self._source)
except Exception as e:
await self.close()
if not isinstance(e, EventError):
raise IOError("Failed to read from stream.") from e
raise

if event is None:
return None

Expand All @@ -57,10 +68,18 @@ async def receive(self) -> E | None:
return result

async def close(self) -> None:
if self._closed:
return
self._closed = True

if (close := getattr(self._source, "close", None)) is not None:
if asyncio.iscoroutine(result := close()):
await result

@property
def closed(self) -> bool:
return self._closed


class EventDeserializer(SpecificShapeDeserializer):
def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@ def __init__(
self._serializer = EventSerializer(
payload_codec=payload_codec, is_client_mode=is_client_mode
)
self._closed = False

async def send(self, event: E) -> None:
if self._closed:
raise IOError("Attempted to write to closed stream.")
event.serialize(self._serializer)
result = self._serializer.get_result()
if result is None:
Expand All @@ -60,13 +63,27 @@ async def send(self, event: E) -> None:
)
if self._signer is not None:
result = self._signer(result)
await self._writer.write(result.encode())

encoded_result = result.encode()
try:
await self._writer.write(encoded_result)
except Exception as e:
await self.close()
raise IOError("Failed to write to stream.") from e

async def close(self) -> None:
if self._closed:
return
self._closed = True

if (close := getattr(self._writer, "close", None)) is not None:
if asyncio.iscoroutine(result := close()):
await result

@property
def closed(self) -> bool:
return self._closed


class EventSerializer(SpecificShapeSerializer):
def __init__(
Expand Down
52 changes: 30 additions & 22 deletions packages/aws-event-stream/src/aws_event_stream/aio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,31 +75,35 @@ def __init__(
self.response: R | None = None

async def await_output(self) -> tuple[R, AsyncEventReceiver[O]]:
async_reader = AsyncBytesReader((await self._awaitable_response).body)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is formatted a bit unfortunately. I just wrapped it in a try/except

if self.output_stream is None:
self.output_stream = _AWSEventReceiver[O](
payload_codec=self._payload_codec,
source=async_reader,
deserializer=self._deserializer,
is_client_mode=self._is_client_mode,
)

if self.response is None:
if self._deserializeable_response is None:
initial_response = await self._awaitable_output
else:
initial_response_stream = _AWSEventReceiver(
try:
async_reader = AsyncBytesReader((await self._awaitable_response).body)
if self.output_stream is None:
self.output_stream = _AWSEventReceiver[O](
payload_codec=self._payload_codec,
source=async_reader,
deserializer=self._deserializeable_response.deserialize,
deserializer=self._deserializer,
is_client_mode=self._is_client_mode,
)
initial_response = await initial_response_stream.receive()
if initial_response is None:
raise MissingInitialResponse()
self.response = initial_response
else:
initial_response = self.response

if self.response is None:
if self._deserializeable_response is None:
initial_response = await self._awaitable_output
else:
initial_response_stream = _AWSEventReceiver(
payload_codec=self._payload_codec,
source=async_reader,
deserializer=self._deserializeable_response.deserialize,
is_client_mode=self._is_client_mode,
)
initial_response = await initial_response_stream.receive()
if initial_response is None:
raise MissingInitialResponse()
self.response = initial_response
else:
initial_response = self.response
except Exception:
await self.input_stream.close()
raise

return initial_response, self.output_stream

Expand Down Expand Up @@ -137,7 +141,11 @@ def __init__(

async def await_output(self) -> R:
if self.response is None:
self.response = await self._awaitable_response
try:
self.response = await self._awaitable_response
except Exception:
await self.input_stream.close()
raise
return self.response


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,43 @@ def test_deserialize_unmodeled_error():

with pytest.raises(UnmodeledEventError, match="InternalError"):
EventStreamOperationInputOutput.deserialize(deserializer)


async def test_receiver_closes_source() -> None:
source = AsyncBytesReader(b"")
deserializer = EventStreamDeserializer()
receiver = AWSAsyncEventReceiver[Any](
payload_codec=JSONCodec(), source=source, deserializer=deserializer.deserialize
)
assert not receiver.closed
assert not source.closed
await receiver.close()
assert receiver.closed
assert source.closed


async def test_read_closed_receiver() -> None:
source = AsyncBytesReader(b"")
deserializer = EventStreamDeserializer()
receiver = AWSAsyncEventReceiver[Any](
payload_codec=JSONCodec(), source=source, deserializer=deserializer.deserialize
)

await receiver.close()
assert receiver.closed
assert await receiver.receive() is None


async def test_read_closed_receiver_source() -> None:
source = AsyncBytesReader(b"")
deserializer = EventStreamDeserializer()
receiver = AWSAsyncEventReceiver[Any](
payload_codec=JSONCodec(), source=source, deserializer=deserializer.deserialize
)

await source.close()
assert source.closed
assert not receiver.closed
with pytest.raises(IOError):
await receiver.receive()
assert receiver.closed
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any

import pytest
from smithy_core.serializers import SerializeableShape
from smithy_core.aio.types import AsyncBytesProvider
from smithy_json import JSONCodec

from aws_event_stream._private.serializers import EventSerializer
from aws_event_stream._private.serializers import (
EventSerializer,
AWSAsyncEventPublisher,
)
from aws_event_stream.events import EventMessage

from . import EVENT_STREAM_SERDE_CASES, INITIAL_REQUEST_CASE, INITIAL_RESPONSE_CASE
Expand Down Expand Up @@ -36,3 +42,41 @@ def test_serialize_initial_request():

def test_serialize_initial_response():
test_event_serializer_server_mode(*INITIAL_RESPONSE_CASE)


async def test_publisher_closes_reader():
writer = AsyncBytesProvider()
publisher: AWSAsyncEventPublisher[Any] = AWSAsyncEventPublisher(
payload_codec=JSONCodec(), async_writer=writer
)

assert not publisher.closed
assert not writer.closed
await publisher.close()
assert publisher.closed
assert writer.closed


async def test_send_after_close():
writer = AsyncBytesProvider()
publisher: AWSAsyncEventPublisher[Any] = AWSAsyncEventPublisher(
payload_codec=JSONCodec(), async_writer=writer
)

await publisher.close()
assert publisher.closed
with pytest.raises(IOError):
await publisher.send(EVENT_STREAM_SERDE_CASES[0][0])


async def test_send_to_closed_writer():
writer = AsyncBytesProvider()
publisher: AWSAsyncEventPublisher[Any] = AWSAsyncEventPublisher(
payload_codec=JSONCodec(), async_writer=writer
)

await writer.close()
with pytest.raises(IOError):
await publisher.send(EVENT_STREAM_SERDE_CASES[0][0])

assert publisher.closed
2 changes: 1 addition & 1 deletion packages/smithy-core/src/smithy_core/aio/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ def closed(self) -> bool:

async def close(self) -> None:
"""Closes the stream, as well as the underlying stream where possible."""
self._closed = True
if (close := getattr(self._data, "close", None)) is not None:
if asyncio.iscoroutine(result := close()):
await result

self._data = None
self._closed = True


class SeekableAsyncBytesReader:
Expand Down
Loading