Skip to content

Commit 2848899

Browse files
Check for async when closing
Protocol isinstance does *not* check whether a function is sync or not. This adds in those checks to the various closeable checks.
1 parent fc3257e commit 2848899

File tree

5 files changed

+59
-22
lines changed

5 files changed

+59
-22
lines changed

python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3+
import asyncio
34
import datetime
45
from collections.abc import Callable
56

6-
from smithy_core.aio.interfaces import AsyncByteStream, AsyncCloseable
7+
from smithy_core.aio.interfaces import AsyncByteStream
78
from smithy_core.codecs import Codec
89
from smithy_core.deserializers import (
910
DeserializeableShape,
@@ -45,8 +46,9 @@ async def receive(self) -> E | None:
4546
return self._deserializer(deserializer)
4647

4748
async def close(self) -> None:
48-
if isinstance(self._source, AsyncCloseable):
49-
await self._source.close()
49+
if (close := getattr(self._source, "close", None)) is not None:
50+
if asyncio.iscoroutine(result := close()):
51+
await result
5052

5153

5254
class EventDeserializer(SpecificShapeDeserializer):

python-packages/aws-event-stream/aws_event_stream/_private/serializers.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3+
import asyncio
34
import datetime
45
from collections.abc import Callable, Iterator
56
from contextlib import contextmanager
67
from io import BytesIO
78
from typing import Never
89

9-
from smithy_core.aio.interfaces import AsyncCloseable, AsyncWriter
10+
from smithy_core.aio.interfaces import AsyncWriter
1011
from smithy_core.codecs import Codec
1112
from smithy_core.exceptions import ExpectationNotMetException
1213
from smithy_core.schemas import Schema
@@ -64,8 +65,9 @@ async def send(self, event: E) -> None:
6465
await self._writer.write(result.encode())
6566

6667
async def close(self) -> None:
67-
if isinstance(self._writer, AsyncCloseable):
68-
await self._writer.close()
68+
if (close := getattr(self._writer, "close", None)) is not None:
69+
if asyncio.iscoroutine(result := close()):
70+
await result
6971

7072

7173
class EventSerializer(SpecificShapeSerializer):

python-packages/smithy-core/smithy_core/aio/interfaces/__init__.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,6 @@ class AsyncWriter(Protocol):
2121
async def write(self, data: bytes) -> None: ...
2222

2323

24-
@runtime_checkable
25-
class AsyncCloseable(Protocol):
26-
"""An object that can asynchronously close."""
27-
28-
async def close(self): ...
29-
30-
3124
# A union of all acceptable streaming blob types. Deserialized payloads will
3225
# always return a ByteStream, or AsyncByteStream if async is enabled.
3326
type StreamingBlob = SyncStreamingBlob | AsyncByteStream | AsyncIterable[bytes]

python-packages/smithy-core/smithy_core/aio/types.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,12 @@ def closed(self) -> bool:
110110
"""Returns whether the stream is closed."""
111111
return self._closed
112112

113-
def close(self) -> None:
113+
async def close(self) -> None:
114114
"""Closes the stream, as well as the underlying stream where possible."""
115115
if (close := getattr(self._data, "close", None)) is not None:
116-
close()
116+
if asyncio.iscoroutine(result := close()):
117+
await result
118+
117119
self._data = None
118120
self._closed = True
119121

@@ -244,10 +246,12 @@ def closed(self) -> bool:
244246
"""Returns whether the stream is closed."""
245247
return self._buffer.closed
246248

247-
def close(self) -> None:
249+
async def close(self) -> None:
248250
"""Closes the stream, as well as the underlying stream where possible."""
249-
if callable(close_fn := getattr(self._data_source, "close", None)):
250-
close_fn() # pylint: disable=not-callable
251+
if (close := getattr(self._data_source, "close", None)) is not None:
252+
if asyncio.iscoroutine(result := close()):
253+
await result
254+
251255
self._data_source = None
252256
self._buffer.close()
253257

python-packages/smithy-core/tests/unit/aio/test_types.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,23 @@ async def test_close_closeable_source() -> None:
138138
assert not reader.closed
139139
assert not source.closed
140140

141-
reader.close()
141+
await reader.close()
142+
143+
assert reader.closed
144+
assert source.closed
145+
146+
with pytest.raises(ValueError):
147+
await reader.read()
148+
149+
150+
async def test_close_async_closeable_source() -> None:
151+
source = AsyncBytesReader(BytesIO(b"foo"))
152+
reader = AsyncBytesReader(source)
153+
154+
assert not reader.closed
155+
assert not source.closed
156+
157+
await reader.close()
142158

143159
assert reader.closed
144160
assert source.closed
@@ -152,7 +168,7 @@ async def test_close_non_closeable_source() -> None:
152168
reader = AsyncBytesReader(source)
153169

154170
assert not reader.closed
155-
reader.close()
171+
await reader.close()
156172
assert reader.closed
157173

158174
with pytest.raises(ValueError):
@@ -167,7 +183,27 @@ async def test_seekable_close_closeable_source() -> None:
167183
assert not source.closed
168184
assert reader.tell() == 0
169185

170-
reader.close()
186+
await reader.close()
187+
188+
assert reader.closed
189+
assert source.closed
190+
191+
with pytest.raises(ValueError):
192+
await reader.read()
193+
194+
with pytest.raises(ValueError):
195+
reader.tell()
196+
197+
198+
async def test_seekable_close_async_closeable_source() -> None:
199+
source = AsyncBytesReader(BytesIO(b"foo"))
200+
reader = SeekableAsyncBytesReader(source)
201+
202+
assert not reader.closed
203+
assert not source.closed
204+
assert reader.tell() == 0
205+
206+
await reader.close()
171207

172208
assert reader.closed
173209
assert source.closed
@@ -185,7 +221,7 @@ async def test_seekable_close_non_closeable_source() -> None:
185221

186222
assert not reader.closed
187223
assert reader.tell() == 0
188-
reader.close()
224+
await reader.close()
189225
assert reader.closed
190226

191227
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)