Skip to content

Commit c11f1de

Browse files
Add sync/async close wrapper util
1 parent 2e0719e commit c11f1de

File tree

4 files changed

+44
-17
lines changed

4 files changed

+44
-17
lines changed

packages/smithy-core/src/smithy_core/aio/types.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ..exceptions import SmithyException
1111
from ..interfaces import BytesReader
1212
from .interfaces import AsyncByteStream, StreamingBlob
13+
from .utils import close
1314

1415
# The default chunk size for iterating streams.
1516
_DEFAULT_CHUNK_SIZE = 1024
@@ -115,10 +116,7 @@ def closed(self) -> bool:
115116
async def close(self) -> None:
116117
"""Closes the stream, as well as the underlying stream where possible."""
117118
self._closed = True
118-
if (close := getattr(self._data, "close", None)) is not None:
119-
if asyncio.iscoroutine(result := close()):
120-
await result
121-
119+
await close(self._data)
122120
self._data = None
123121

124122

@@ -250,12 +248,9 @@ def closed(self) -> bool:
250248

251249
async def close(self) -> None:
252250
"""Closes the stream, as well as the underlying stream where possible."""
253-
if (close := getattr(self._data_source, "close", None)) is not None:
254-
if asyncio.iscoroutine(result := close()):
255-
await result
256-
257-
self._data_source = None
258251
self._buffer.close()
252+
await close(self._data_source)
253+
self._data_source = None
259254

260255

261256
class _AsyncByteStreamIterator:

packages/smithy-core/src/smithy_core/aio/utils.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3-
from asyncio import sleep
3+
from asyncio import sleep, iscoroutine
44
from collections.abc import AsyncIterable, Iterable
5-
from typing import TypeVar
5+
from typing import Any
66

77
from ..exceptions import AsyncBodyException
88
from ..interfaces import BytesReader
99
from ..interfaces import StreamingBlob as SyncStreamingBlob
1010
from .interfaces import AsyncByteStream, StreamingBlob
1111

12-
_ListEl = TypeVar("_ListEl")
1312

14-
15-
async def async_list(lst: Iterable[_ListEl]) -> AsyncIterable[_ListEl]:
13+
async def async_list[E](lst: Iterable[E]) -> AsyncIterable[E]:
1614
"""Turn an Iterable into an AsyncIterable."""
1715
for x in lst:
1816
await sleep(0)
@@ -53,3 +51,10 @@ def read_streaming_blob(body: StreamingBlob) -> bytes:
5351
raise AsyncBodyException(
5452
f"Expected type {SyncStreamingBlob}, but was {type(body)}"
5553
)
54+
55+
56+
async def close(stream: Any) -> None:
57+
"""Close a stream, awaiting it if it's async."""
58+
if (close := getattr(stream, "close", None)) is not None:
59+
if iscoroutine(result := close()):
60+
await result
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
from io import BytesIO
4+
5+
from smithy_core.aio.types import AsyncBytesProvider
6+
from smithy_core.aio.utils import close
7+
8+
9+
async def test_close_sync_closeable() -> None:
10+
stream = BytesIO()
11+
assert not stream.closed
12+
await close(stream)
13+
assert stream.closed
14+
15+
16+
async def test_close_async_closeable() -> None:
17+
stream = AsyncBytesProvider()
18+
assert not stream.closed
19+
await close(stream)
20+
assert stream.closed
21+
22+
23+
async def test_close_non_closeable() -> None:
24+
await close(b"foo")

packages/smithy-http/src/smithy_http/aio/crt.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
from smithy_core import interfaces as core_interfaces
3737
from smithy_core.aio.types import AsyncBytesReader
38+
from smithy_core.aio.utils import close
3839
from smithy_core.exceptions import MissingDependencyException
3940

4041
from .. import Field, Fields
@@ -361,7 +362,9 @@ async def _marshal_request(
361362
# asyncio.create_task we'll start the coroutine without having to
362363
# explicitly await it.
363364
crt_body = BufferableByteStream()
364-
body = AsyncBytesReader(body)
365+
366+
if not isinstance(body, AsyncIterable):
367+
body = AsyncBytesReader(body)
365368

366369
# Start the read task in the background.
367370
read_task = asyncio.create_task(self._consume_body_async(body, crt_body))
@@ -380,7 +383,7 @@ async def _marshal_request(
380383
return crt_request, crt_body
381384

382385
async def _consume_body_async(
383-
self, source: AsyncBytesReader, dest: "BufferableByteStream"
386+
self, source: AsyncIterable[bytes], dest: "BufferableByteStream"
384387
) -> None:
385388
try:
386389
async for chunk in source:
@@ -389,7 +392,7 @@ async def _consume_body_async(
389392
dest.close()
390393
raise
391394
finally:
392-
await source.close()
395+
await close(source)
393396
dest.end_stream()
394397

395398
def __deepcopy__(self, memo: Any) -> "AWSCRTHTTPClient":

0 commit comments

Comments
 (0)