diff --git a/packages/aws-event-stream/src/aws_event_stream/_private/deserializers.py b/packages/aws-event-stream/src/aws_event_stream/_private/deserializers.py index 864038d57..c9cfdbb4f 100644 --- a/packages/aws-event-stream/src/aws_event_stream/_private/deserializers.py +++ b/packages/aws-event-stream/src/aws_event_stream/_private/deserializers.py @@ -40,9 +40,20 @@ def __init__( self._source = source self._is_client_mode = is_client_mode self._deserializer = deserializer + self._closed = False async def receive(self) -> E | None: - event = await Event.decode_async(self._source) + if self._closed: + return None + + try: + event = await Event.decode_async(self._source) + except Exception as e: + await self.close() + if not isinstance(e, EventError): + raise IOError("Failed to read from stream.") from e + raise + if event is None: return None @@ -57,10 +68,18 @@ async def receive(self) -> E | None: return result async def close(self) -> None: + if self._closed: + return + self._closed = True + if (close := getattr(self._source, "close", None)) is not None: if asyncio.iscoroutine(result := close()): await result + @property + def closed(self) -> bool: + return self._closed + class EventDeserializer(SpecificShapeDeserializer): def __init__( diff --git a/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py b/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py index 32bb4935f..933381c4b 100644 --- a/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py +++ b/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py @@ -50,8 +50,11 @@ def __init__( self._serializer = EventSerializer( payload_codec=payload_codec, is_client_mode=is_client_mode ) + self._closed = False async def send(self, event: E) -> None: + if self._closed: + raise IOError("Attempted to write to closed stream.") event.serialize(self._serializer) result = self._serializer.get_result() if result is None: @@ -60,13 +63,27 @@ async def 
send(self, event: E) -> None: ) if self._signer is not None: result = self._signer(result) - await self._writer.write(result.encode()) + + encoded_result = result.encode() + try: + await self._writer.write(encoded_result) + except Exception as e: + await self.close() + raise IOError("Failed to write to stream.") from e async def close(self) -> None: + if self._closed: + return + self._closed = True + if (close := getattr(self._writer, "close", None)) is not None: if asyncio.iscoroutine(result := close()): await result + @property + def closed(self) -> bool: + return self._closed + class EventSerializer(SpecificShapeSerializer): def __init__( diff --git a/packages/aws-event-stream/src/aws_event_stream/aio/__init__.py b/packages/aws-event-stream/src/aws_event_stream/aio/__init__.py index 76943f52e..a71053c91 100644 --- a/packages/aws-event-stream/src/aws_event_stream/aio/__init__.py +++ b/packages/aws-event-stream/src/aws_event_stream/aio/__init__.py @@ -75,31 +75,35 @@ def __init__( self.response: R | None = None async def await_output(self) -> tuple[R, AsyncEventReceiver[O]]: - async_reader = AsyncBytesReader((await self._awaitable_response).body) - if self.output_stream is None: - self.output_stream = _AWSEventReceiver[O]( - payload_codec=self._payload_codec, - source=async_reader, - deserializer=self._deserializer, - is_client_mode=self._is_client_mode, - ) - - if self.response is None: - if self._deserializeable_response is None: - initial_response = await self._awaitable_output - else: - initial_response_stream = _AWSEventReceiver( + try: + async_reader = AsyncBytesReader((await self._awaitable_response).body) + if self.output_stream is None: + self.output_stream = _AWSEventReceiver[O]( payload_codec=self._payload_codec, source=async_reader, - deserializer=self._deserializeable_response.deserialize, + deserializer=self._deserializer, is_client_mode=self._is_client_mode, ) - initial_response = await initial_response_stream.receive() - if initial_response is 
None: - raise MissingInitialResponse() - self.response = initial_response - else: - initial_response = self.response + + if self.response is None: + if self._deserializeable_response is None: + initial_response = await self._awaitable_output + else: + initial_response_stream = _AWSEventReceiver( + payload_codec=self._payload_codec, + source=async_reader, + deserializer=self._deserializeable_response.deserialize, + is_client_mode=self._is_client_mode, + ) + initial_response = await initial_response_stream.receive() + if initial_response is None: + raise MissingInitialResponse() + self.response = initial_response + else: + initial_response = self.response + except Exception: + await self.input_stream.close() + raise return initial_response, self.output_stream @@ -137,7 +141,11 @@ def __init__( async def await_output(self) -> R: if self.response is None: - self.response = await self._awaitable_response + try: + self.response = await self._awaitable_response + except Exception: + await self.input_stream.close() + raise return self.response diff --git a/packages/aws-event-stream/tests/unit/_private/test_deserializers.py b/packages/aws-event-stream/tests/unit/_private/test_deserializers.py index a153864e9..e4213c0f8 100644 --- a/packages/aws-event-stream/tests/unit/_private/test_deserializers.py +++ b/packages/aws-event-stream/tests/unit/_private/test_deserializers.py @@ -89,3 +89,43 @@ def test_deserialize_unmodeled_error(): with pytest.raises(UnmodeledEventError, match="InternalError"): EventStreamOperationInputOutput.deserialize(deserializer) + + +async def test_receiver_closes_source() -> None: + source = AsyncBytesReader(b"") + deserializer = EventStreamDeserializer() + receiver = AWSAsyncEventReceiver[Any]( + payload_codec=JSONCodec(), source=source, deserializer=deserializer.deserialize + ) + assert not receiver.closed + assert not source.closed + await receiver.close() + assert receiver.closed + assert source.closed + + +async def test_read_closed_receiver() -> 
None: + source = AsyncBytesReader(b"") + deserializer = EventStreamDeserializer() + receiver = AWSAsyncEventReceiver[Any]( + payload_codec=JSONCodec(), source=source, deserializer=deserializer.deserialize + ) + + await receiver.close() + assert receiver.closed + assert await receiver.receive() is None + + +async def test_read_closed_receiver_source() -> None: + source = AsyncBytesReader(b"") + deserializer = EventStreamDeserializer() + receiver = AWSAsyncEventReceiver[Any]( + payload_codec=JSONCodec(), source=source, deserializer=deserializer.deserialize + ) + + await source.close() + assert source.closed + assert not receiver.closed + with pytest.raises(IOError): + await receiver.receive() + assert receiver.closed diff --git a/packages/aws-event-stream/tests/unit/_private/test_serializers.py b/packages/aws-event-stream/tests/unit/_private/test_serializers.py index c3f16277e..a8134acdf 100644 --- a/packages/aws-event-stream/tests/unit/_private/test_serializers.py +++ b/packages/aws-event-stream/tests/unit/_private/test_serializers.py @@ -1,10 +1,16 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from typing import Any + import pytest from smithy_core.serializers import SerializeableShape +from smithy_core.aio.types import AsyncBytesProvider from smithy_json import JSONCodec -from aws_event_stream._private.serializers import EventSerializer +from aws_event_stream._private.serializers import ( + EventSerializer, + AWSAsyncEventPublisher, +) from aws_event_stream.events import EventMessage from . 
import EVENT_STREAM_SERDE_CASES, INITIAL_REQUEST_CASE, INITIAL_RESPONSE_CASE @@ -36,3 +42,41 @@ def test_serialize_initial_request(): def test_serialize_initial_response(): test_event_serializer_server_mode(*INITIAL_RESPONSE_CASE) + + +async def test_publisher_closes_reader(): + writer = AsyncBytesProvider() + publisher: AWSAsyncEventPublisher[Any] = AWSAsyncEventPublisher( + payload_codec=JSONCodec(), async_writer=writer + ) + + assert not publisher.closed + assert not writer.closed + await publisher.close() + assert publisher.closed + assert writer.closed + + +async def test_send_after_close(): + writer = AsyncBytesProvider() + publisher: AWSAsyncEventPublisher[Any] = AWSAsyncEventPublisher( + payload_codec=JSONCodec(), async_writer=writer + ) + + await publisher.close() + assert publisher.closed + with pytest.raises(IOError): + await publisher.send(EVENT_STREAM_SERDE_CASES[0][0]) + + +async def test_send_to_closed_writer(): + writer = AsyncBytesProvider() + publisher: AWSAsyncEventPublisher[Any] = AWSAsyncEventPublisher( + payload_codec=JSONCodec(), async_writer=writer + ) + + await writer.close() + with pytest.raises(IOError): + await publisher.send(EVENT_STREAM_SERDE_CASES[0][0]) + + assert publisher.closed diff --git a/packages/smithy-core/src/smithy_core/aio/types.py b/packages/smithy-core/src/smithy_core/aio/types.py index 6876654b6..91e0244b9 100644 --- a/packages/smithy-core/src/smithy_core/aio/types.py +++ b/packages/smithy-core/src/smithy_core/aio/types.py @@ -10,6 +10,7 @@ from ..exceptions import SmithyException from ..interfaces import BytesReader from .interfaces import AsyncByteStream, StreamingBlob +from .utils import close # The default chunk size for iterating streams. 
_DEFAULT_CHUNK_SIZE = 1024 @@ -114,12 +115,9 @@ def closed(self) -> bool: async def close(self) -> None: """Closes the stream, as well as the underlying stream where possible.""" - if (close := getattr(self._data, "close", None)) is not None: - if asyncio.iscoroutine(result := close()): - await result - - self._data = None self._closed = True + await close(self._data) + self._data = None class SeekableAsyncBytesReader: @@ -250,12 +248,9 @@ def closed(self) -> bool: async def close(self) -> None: """Closes the stream, as well as the underlying stream where possible.""" - if (close := getattr(self._data_source, "close", None)) is not None: - if asyncio.iscoroutine(result := close()): - await result - - self._data_source = None self._buffer.close() + await close(self._data_source) + self._data_source = None class _AsyncByteStreamIterator: diff --git a/packages/smithy-core/src/smithy_core/aio/utils.py b/packages/smithy-core/src/smithy_core/aio/utils.py index 9d1f77a01..fbd611ec8 100644 --- a/packages/smithy-core/src/smithy_core/aio/utils.py +++ b/packages/smithy-core/src/smithy_core/aio/utils.py @@ -1,18 +1,16 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -from asyncio import sleep +from asyncio import sleep, iscoroutine from collections.abc import AsyncIterable, Iterable -from typing import TypeVar +from typing import Any from ..exceptions import AsyncBodyException from ..interfaces import BytesReader from ..interfaces import StreamingBlob as SyncStreamingBlob from .interfaces import AsyncByteStream, StreamingBlob -_ListEl = TypeVar("_ListEl") - -async def async_list(lst: Iterable[_ListEl]) -> AsyncIterable[_ListEl]: +async def async_list[E](lst: Iterable[E]) -> AsyncIterable[E]: """Turn an Iterable into an AsyncIterable.""" for x in lst: await sleep(0) @@ -53,3 +51,10 @@ def read_streaming_blob(body: StreamingBlob) -> bytes: raise AsyncBodyException( f"Expected type {SyncStreamingBlob}, but was {type(body)}" ) + + +async def close(stream: Any) -> None: + """Close a stream, awaiting it if it's async.""" + if (close := getattr(stream, "close", None)) is not None: + if iscoroutine(result := close()): + await result diff --git a/packages/smithy-core/tests/unit/aio/test_utils.py b/packages/smithy-core/tests/unit/aio/test_utils.py new file mode 100644 index 000000000..d9084db8b --- /dev/null +++ b/packages/smithy-core/tests/unit/aio/test_utils.py @@ -0,0 +1,24 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from io import BytesIO + +from smithy_core.aio.types import AsyncBytesProvider +from smithy_core.aio.utils import close + + +async def test_close_sync_closeable() -> None: + stream = BytesIO() + assert not stream.closed + await close(stream) + assert stream.closed + + +async def test_close_async_closeable() -> None: + stream = AsyncBytesProvider() + assert not stream.closed + await close(stream) + assert stream.closed + + +async def test_close_non_closeable() -> None: + await close(b"foo") diff --git a/packages/smithy-http/src/smithy_http/aio/crt.py b/packages/smithy-http/src/smithy_http/aio/crt.py index f0a6ec69c..0e17eba49 100644 --- a/packages/smithy-http/src/smithy_http/aio/crt.py +++ b/packages/smithy-http/src/smithy_http/aio/crt.py @@ -3,10 +3,12 @@ # pyright: reportMissingTypeStubs=false,reportUnknownMemberType=false # flake8: noqa: F811 import asyncio +from asyncio import Future as AsyncFuture from concurrent.futures import Future as ConcurrentFuture from collections import deque from collections.abc import AsyncGenerator, AsyncIterable from copy import deepcopy +from functools import partial from io import BytesIO, BufferedIOBase from typing import TYPE_CHECKING, Any @@ -33,6 +35,7 @@ from smithy_core import interfaces as core_interfaces from smithy_core.aio.types import AsyncBytesReader +from smithy_core.aio.utils import close from smithy_core.exceptions import MissingDependencyException from .. 
import Field, Fields @@ -104,6 +107,7 @@ def __repr__(self) -> str: class CRTResponseBody: def __init__(self) -> None: self._stream: "crt_http.HttpClientStream | None" = None + self._completion_future: AsyncFuture[int] | None = None self._chunk_futures: deque[ConcurrentFuture[bytes]] = deque() # deque is thread safe and the crt is only going to be writing @@ -116,7 +120,9 @@ def set_stream(self, stream: "crt_http.HttpClientStream") -> None: if self._stream is not None: raise SmithyHTTPException("Stream already set on AWSCRTHTTPResponse object") self._stream = stream - self._stream.completion_future.add_done_callback(self._on_complete) + concurrent_future: ConcurrentFuture[int] = stream.completion_future + self._completion_future = asyncio.wrap_future(concurrent_future) + self._completion_future.add_done_callback(self._on_complete) self._stream.activate() def on_body(self, chunk: bytes, **kwargs: Any) -> None: # pragma: crt-callback @@ -128,13 +134,13 @@ def on_body(self, chunk: bytes, **kwargs: Any) -> None: # pragma: crt-callback self._received_chunks.append(chunk) async def next(self) -> bytes: - if self._stream is None: + if self._completion_future is None: raise SmithyHTTPException("Stream not set") # TODO: update backpressure window once CRT supports it if self._received_chunks: return self._received_chunks.popleft() - elif self._stream.completion_future.done(): + elif self._completion_future.done(): return b"" else: future = ConcurrentFuture[bytes]() @@ -142,7 +148,7 @@ async def next(self) -> bytes: return await asyncio.wrap_future(future) def _on_complete( - self, completion_future: ConcurrentFuture[int] + self, completion_future: AsyncFuture[int] ) -> None: # pragma: crt-callback for future in self._chunk_futures: future.set_result(b"") @@ -231,7 +237,7 @@ async def send( :param request: The request including destination URI, fields, payload. :param request_config: Configuration specific to this request. 
""" - crt_request = await self._marshal_request(request) + crt_request, crt_body = await self._marshal_request(request) connection = await self._get_connection(request.destination) response_body = CRTResponseBody() response_factory = CRTResponseFactory(response_body) @@ -242,8 +248,17 @@ async def send( ) response_factory.set_done_callback(crt_stream) response_body.set_stream(crt_stream) + crt_stream.completion_future.add_done_callback( + partial(self._close_input_body, body=crt_body) + ) return await response_factory.await_response() + def _close_input_body( + self, future: ConcurrentFuture[int], *, body: "BufferableByteStream | BytesIO" + ) -> None: + if future.exception(timeout=0): + body.close() + async def _create_connection( self, url: core_interfaces.URI ) -> "crt_http.HttpClientConnection": @@ -314,7 +329,7 @@ def _render_path(self, url: core_interfaces.URI) -> str: async def _marshal_request( self, request: http_aio_interfaces.HTTPRequest - ) -> "crt_http.HttpRequest": + ) -> tuple["crt_http.HttpRequest", "BufferableByteStream | BytesIO"]: """Create :py:class:`awscrt.http.HttpRequest` from :py:class:`smithy_http.aio.HTTPRequest`""" headers_list = [] @@ -343,12 +358,12 @@ async def _marshal_request( crt_body = BytesIO(body) else: # If the body is async, or potentially very large, start up a task to read - # it into the BytesIO object that CRT needs. By using asyncio.create_task - # we'll start the coroutine without having to explicitly await it. + # it into the intermediate object that CRT needs. By using + # asyncio.create_task we'll start the coroutine without having to + # explicitly await it. crt_body = BufferableByteStream() + if not isinstance(body, AsyncIterable): - # If the body isn't already an async iterable, wrap it in one. Objects - # with read methods will be read in chunks so as not to exhaust memory. body = AsyncBytesReader(body) # Start the read task in the background. 
@@ -365,13 +380,19 @@ async def _marshal_request( headers=headers, body_stream=crt_body, ) - return crt_request + return crt_request, crt_body async def _consume_body_async( self, source: AsyncIterable[bytes], dest: "BufferableByteStream" ) -> None: - async for chunk in source: - dest.write(chunk) + try: + async for chunk in source: + dest.write(chunk) + except Exception: + dest.close() + raise + finally: + await close(source) dest.end_stream() def __deepcopy__(self, memo: Any) -> "AWSCRTHTTPClient": diff --git a/packages/smithy-http/tests/unit/aio/test_crt.py b/packages/smithy-http/tests/unit/aio/test_crt.py index d766601e6..89b04bcd0 100644 --- a/packages/smithy-http/tests/unit/aio/test_crt.py +++ b/packages/smithy-http/tests/unit/aio/test_crt.py @@ -1,14 +1,18 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import asyncio from copy import deepcopy from io import BytesIO +from unittest.mock import Mock +from concurrent.futures import Future as ConcurrentFuture import pytest +from awscrt.http import HttpClientStream # type: ignore from smithy_core import URI from smithy_http import Fields from smithy_http.aio import HTTPRequest -from smithy_http.aio.crt import AWSCRTHTTPClient, BufferableByteStream +from smithy_http.aio.crt import AWSCRTHTTPClient, BufferableByteStream, CRTResponseBody def test_deepcopy_client() -> None: @@ -26,7 +30,7 @@ async def test_client_marshal_request() -> None: body=BytesIO(), fields=Fields(), ) - crt_request = await client._marshal_request(request) # type: ignore + crt_request, _ = await client._marshal_request(request) # type: ignore assert crt_request.headers.get("host") == "example.com" # type: ignore assert crt_request.headers.get("accept") == "*/*" # type: ignore assert crt_request.method == "GET" # type: ignore @@ -136,3 +140,77 @@ def test_end_stream() -> None: assert not stream.closed assert stream.read() == b"foo" assert stream.closed + + +async def 
test_response_body_completed_stream() -> None: + completion_future = ConcurrentFuture[int]() + mock_stream = Mock(spec=HttpClientStream) + mock_stream.completion_future = completion_future + + response_body = CRTResponseBody() + response_body.set_stream(mock_stream) + completion_future.set_result(200) + + assert await response_body.next() == b"" + + +async def test_response_body_empty_stream() -> None: + completion_future = ConcurrentFuture[int]() + mock_stream = Mock(spec=HttpClientStream) + mock_stream.completion_future = completion_future + + response_body = CRTResponseBody() + response_body.set_stream(mock_stream) + + read_task = asyncio.create_task(response_body.next()) + + # Sleep briefly so the read task gets priority. It should + # add a chunk future and then await it. + await asyncio.sleep(0.01) + + assert len(response_body._chunk_futures) == 1 # type: ignore + response_body.on_body(b"foo") + assert await read_task == b"foo" + + +async def test_response_body_stream_completion_clears_buffer() -> None: + completion_future = ConcurrentFuture[int]() + mock_stream = Mock(spec=HttpClientStream) + mock_stream.completion_future = completion_future + + response_body = CRTResponseBody() + response_body.set_stream(mock_stream) + + read_tasks = ( + asyncio.create_task(response_body.next()), + asyncio.create_task(response_body.next()), + asyncio.create_task(response_body.next()), + asyncio.create_task(response_body.next()), + ) + + # Sleep briefly so the read tasks get priority. Each should + # add a chunk future and then await it. 
+ await asyncio.sleep(0.01) + + assert len(response_body._chunk_futures) == 4 # type: ignore + completion_future.set_result(200) + await asyncio.sleep(0.01) + + # Tasks should have been drained + assert len(response_body._chunk_futures) == 0 # type: ignore + + # Tasks should still be awaited, and should all return empty + results = asyncio.gather(*read_tasks) + assert results.result() == [b"", b"", b"", b""] + + +async def test_response_body_non_empty_stream() -> None: + completion_future = ConcurrentFuture[int]() + mock_stream = Mock(spec=HttpClientStream) + mock_stream.completion_future = completion_future + + response_body = CRTResponseBody() + response_body.set_stream(mock_stream) + response_body.on_body(b"foo") + + assert await response_body.next() == b"foo"