From ee284cdf183ae8217f34d9a6657177a681070370 Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Mon, 25 Nov 2024 18:55:04 +0100 Subject: [PATCH] Check for async when closing Protocol isinstance does *not* check whether a function is sync or not. This adds in those checks to the various closeable checks. --- .../_private/deserializers.py | 8 ++-- .../aws_event_stream/_private/serializers.py | 8 ++-- .../smithy_core/aio/interfaces/__init__.py | 7 --- .../smithy-core/smithy_core/aio/types.py | 14 +++--- .../smithy-core/tests/unit/aio/test_types.py | 44 +++++++++++++++++-- 5 files changed, 59 insertions(+), 22 deletions(-) diff --git a/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py b/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py index f04b7e50e..bd2254119 100644 --- a/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py +++ b/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py @@ -1,9 +1,10 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import asyncio import datetime from collections.abc import Callable -from smithy_core.aio.interfaces import AsyncByteStream, AsyncCloseable +from smithy_core.aio.interfaces import AsyncByteStream from smithy_core.codecs import Codec from smithy_core.deserializers import ( DeserializeableShape, @@ -45,8 +46,9 @@ async def receive(self) -> E | None: return self._deserializer(deserializer) async def close(self) -> None: - if isinstance(self._source, AsyncCloseable): - await self._source.close() + if (close := getattr(self._source, "close", None)) is not None: + if asyncio.iscoroutine(result := close()): + await result class EventDeserializer(SpecificShapeDeserializer): diff --git a/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py b/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py index 3baf1c0e7..88b740fd3 100644 --- a/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py +++ b/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py @@ -1,12 +1,13 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import asyncio import datetime from collections.abc import Callable, Iterator from contextlib import contextmanager from io import BytesIO from typing import Never -from smithy_core.aio.interfaces import AsyncCloseable, AsyncWriter +from smithy_core.aio.interfaces import AsyncWriter from smithy_core.codecs import Codec from smithy_core.exceptions import ExpectationNotMetException from smithy_core.schemas import Schema @@ -64,8 +65,9 @@ async def send(self, event: E) -> None: await self._writer.write(result.encode()) async def close(self) -> None: - if isinstance(self._writer, AsyncCloseable): - await self._writer.close() + if (close := getattr(self._writer, "close", None)) is not None: + if asyncio.iscoroutine(result := close()): + await result class EventSerializer(SpecificShapeSerializer): diff --git a/python-packages/smithy-core/smithy_core/aio/interfaces/__init__.py b/python-packages/smithy-core/smithy_core/aio/interfaces/__init__.py index 79fdb60ca..c96634951 100644 --- a/python-packages/smithy-core/smithy_core/aio/interfaces/__init__.py +++ b/python-packages/smithy-core/smithy_core/aio/interfaces/__init__.py @@ -21,13 +21,6 @@ class AsyncWriter(Protocol): async def write(self, data: bytes) -> None: ... -@runtime_checkable -class AsyncCloseable(Protocol): - """An object that can asynchronously close.""" - - async def close(self): ... - - # A union of all acceptable streaming blob types. Deserialized payloads will # always return a ByteStream, or AsyncByteStream if async is enabled. type StreamingBlob = SyncStreamingBlob | AsyncByteStream | AsyncIterable[bytes] diff --git a/python-packages/smithy-core/smithy_core/aio/types.py b/python-packages/smithy-core/smithy_core/aio/types.py index caedbcdb4..07d7a8bb0 100644 --- a/python-packages/smithy-core/smithy_core/aio/types.py +++ b/python-packages/smithy-core/smithy_core/aio/types.py @@ -110,10 +110,12 @@ def closed(self) -> bool: """Returns whether the stream is closed.""" return self._closed - def close(self) -> None: + async def close(self) -> None: """Closes the stream, as well as the underlying stream where possible.""" if (close := getattr(self._data, "close", None)) is not None: - close() + if asyncio.iscoroutine(result := close()): + await result + self._data = None self._closed = True @@ -244,10 +246,12 @@ def closed(self) -> bool: """Returns whether the stream is closed.""" return self._buffer.closed - def close(self) -> None: + async def close(self) -> None: """Closes the stream, as well as the underlying stream where possible.""" - if callable(close_fn := getattr(self._data_source, "close", None)): - close_fn() # pylint: disable=not-callable + if (close := getattr(self._data_source, "close", None)) is not None: + if asyncio.iscoroutine(result := close()): + await result + self._data_source = None self._buffer.close() diff --git a/python-packages/smithy-core/tests/unit/aio/test_types.py b/python-packages/smithy-core/tests/unit/aio/test_types.py index 330fa23e4..103845708 100644 --- a/python-packages/smithy-core/tests/unit/aio/test_types.py +++ b/python-packages/smithy-core/tests/unit/aio/test_types.py @@ -138,7 +138,23 @@ async def test_close_closeable_source() -> None: assert not reader.closed assert not source.closed - reader.close() + await reader.close() + + assert reader.closed + assert source.closed + + with pytest.raises(ValueError): + await reader.read() + + +async def test_close_async_closeable_source() -> None: + source = AsyncBytesReader(BytesIO(b"foo")) + reader = AsyncBytesReader(source) + + assert not reader.closed + assert not source.closed + + await reader.close() assert reader.closed assert source.closed @@ -152,7 +168,7 @@ async def test_close_non_closeable_source() -> None: reader = AsyncBytesReader(source) assert not reader.closed - reader.close() + await reader.close() assert reader.closed with pytest.raises(ValueError): @@ -167,7 +183,27 @@ async def test_seekable_close_closeable_source() -> None: assert not source.closed assert reader.tell() == 0 - reader.close() + await reader.close() + + assert reader.closed + assert source.closed + + with pytest.raises(ValueError): + await reader.read() + + with pytest.raises(ValueError): + reader.tell() + + +async def test_seekable_close_async_closeable_source() -> None: + source = AsyncBytesReader(BytesIO(b"foo")) + reader = SeekableAsyncBytesReader(source) + + assert not reader.closed + assert not source.closed + assert reader.tell() == 0 + + await reader.close() assert reader.closed assert source.closed @@ -185,7 +221,7 @@ async def test_seekable_close_non_closeable_source() -> None: assert not reader.closed assert reader.tell() == 0 - reader.close() + await reader.close() assert reader.closed with pytest.raises(ValueError):