Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import datetime
from collections.abc import Callable

from smithy_core.aio.interfaces import AsyncByteStream, AsyncCloseable
from smithy_core.aio.interfaces import AsyncByteStream
from smithy_core.codecs import Codec
from smithy_core.deserializers import (
DeserializeableShape,
Expand Down Expand Up @@ -45,8 +46,9 @@ async def receive(self) -> E | None:
return self._deserializer(deserializer)

async def close(self) -> None:
if isinstance(self._source, AsyncCloseable):
await self._source.close()
if (close := getattr(self._source, "close", None)) is not None:
if asyncio.iscoroutine(result := close()):
await result


class EventDeserializer(SpecificShapeDeserializer):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import datetime
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from io import BytesIO
from typing import Never

from smithy_core.aio.interfaces import AsyncCloseable, AsyncWriter
from smithy_core.aio.interfaces import AsyncWriter
from smithy_core.codecs import Codec
from smithy_core.exceptions import ExpectationNotMetException
from smithy_core.schemas import Schema
Expand Down Expand Up @@ -64,8 +65,9 @@ async def send(self, event: E) -> None:
await self._writer.write(result.encode())

async def close(self) -> None:
if isinstance(self._writer, AsyncCloseable):
await self._writer.close()
if (close := getattr(self._writer, "close", None)) is not None:
if asyncio.iscoroutine(result := close()):
await result


class EventSerializer(SpecificShapeSerializer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,6 @@ class AsyncWriter(Protocol):
async def write(self, data: bytes) -> None: ...


@runtime_checkable
class AsyncCloseable(Protocol):
"""An object that can asynchronously close."""

async def close(self): ...


# A union of all acceptable streaming blob types. Deserialized payloads will
# always return a ByteStream, or AsyncByteStream if async is enabled.
type StreamingBlob = SyncStreamingBlob | AsyncByteStream | AsyncIterable[bytes]
Expand Down
14 changes: 9 additions & 5 deletions python-packages/smithy-core/smithy_core/aio/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,12 @@ def closed(self) -> bool:
"""Returns whether the stream is closed."""
return self._closed

def close(self) -> None:
async def close(self) -> None:
"""Closes the stream, as well as the underlying stream where possible."""
if (close := getattr(self._data, "close", None)) is not None:
close()
if asyncio.iscoroutine(result := close()):
await result

self._data = None
self._closed = True

Expand Down Expand Up @@ -244,10 +246,12 @@ def closed(self) -> bool:
"""Returns whether the stream is closed."""
return self._buffer.closed

def close(self) -> None:
async def close(self) -> None:
"""Closes the stream, as well as the underlying stream where possible."""
if callable(close_fn := getattr(self._data_source, "close", None)):
close_fn() # pylint: disable=not-callable
if (close := getattr(self._data_source, "close", None)) is not None:
if asyncio.iscoroutine(result := close()):
await result

self._data_source = None
self._buffer.close()

Expand Down
44 changes: 40 additions & 4 deletions python-packages/smithy-core/tests/unit/aio/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,23 @@ async def test_close_closeable_source() -> None:
assert not reader.closed
assert not source.closed

reader.close()
await reader.close()

assert reader.closed
assert source.closed

with pytest.raises(ValueError):
await reader.read()


async def test_close_async_closeable_source() -> None:
source = AsyncBytesReader(BytesIO(b"foo"))
reader = AsyncBytesReader(source)

assert not reader.closed
assert not source.closed

await reader.close()

assert reader.closed
assert source.closed
Expand All @@ -152,7 +168,7 @@ async def test_close_non_closeable_source() -> None:
reader = AsyncBytesReader(source)

assert not reader.closed
reader.close()
await reader.close()
assert reader.closed

with pytest.raises(ValueError):
Expand All @@ -167,7 +183,27 @@ async def test_seekable_close_closeable_source() -> None:
assert not source.closed
assert reader.tell() == 0

reader.close()
await reader.close()

assert reader.closed
assert source.closed

with pytest.raises(ValueError):
await reader.read()

with pytest.raises(ValueError):
reader.tell()


async def test_seekable_close_async_closeable_source() -> None:
source = AsyncBytesReader(BytesIO(b"foo"))
reader = SeekableAsyncBytesReader(source)

assert not reader.closed
assert not source.closed
assert reader.tell() == 0

await reader.close()

assert reader.closed
assert source.closed
Expand All @@ -185,7 +221,7 @@ async def test_seekable_close_non_closeable_source() -> None:

assert not reader.closed
assert reader.tell() == 0
reader.close()
await reader.close()
assert reader.closed

with pytest.raises(ValueError):
Expand Down
Loading