diff --git a/python-packages/smithy-http/smithy_http/aio/crt.py b/python-packages/smithy-http/smithy_http/aio/crt.py index b064bab5c..8a4dab6d0 100644 --- a/python-packages/smithy-http/smithy_http/aio/crt.py +++ b/python-packages/smithy-http/smithy_http/aio/crt.py @@ -25,7 +25,7 @@ HAS_CRT = False # type: ignore from smithy_core import interfaces as core_interfaces -from smithy_core.aio.types import AsyncBytesReader +from smithy_core.aio import interfaces as core_aio_interfaces from smithy_core.exceptions import MissingDependencyException from .. import Field, Fields @@ -188,7 +188,6 @@ def __init__( self._tls_ctx = crt_io.ClientTlsContext(crt_io.TlsContextOptions()) self._socket_options = crt_io.SocketOptions() self._connections: ConnectionPoolDict = {} - self._async_reads: set[asyncio.Task[Any]] = set() async def send( self, @@ -301,23 +300,14 @@ async def _marshal_request( # If the body is already directly in memory, wrap in a BytesIO to hand # off to CRT. crt_body = BytesIO(body) + elif not self._is_sync_stream(body): + # If the body is an async stream.... read it all into memory. This is + # very unfortunate, but necessary because the CRT doesn't currently + # have the capability to read async. We will likely have to implment + # this capability into it ourselves, or implment a thread-based wrapper. + crt_body = BytesIO(await request.consume_body_async()) else: - # If the body is async, or potentially very large, start up a task to read - # it into the BytesIO object that CRT needs. By using asyncio.create_task - # we'll start the coroutine without having to explicitly await it. - crt_body = BytesIO() - if not isinstance(body, AsyncIterable): - # If the body isn't already an async iterable, wrap it in one. Objects - # with read methods will be read in chunks so as not to exhaust memory. - body = AsyncBytesReader(body) - - # Start the read task in the background. - read_task = asyncio.create_task(self._consume_body_async(body, crt_body)) - - # Keep track of the read task so that it doesn't get garbage colllected, - # and stop tracking it once it's done. - self._async_reads.add(read_task) - read_task.add_done_callback(self._async_reads.discard) + crt_body = body crt_request = crt_http.HttpRequest( method=request.method, @@ -327,10 +317,6 @@ async def _marshal_request( ) return crt_request - async def _consume_body_async( - self, source: AsyncIterable[bytes], dest: BytesIO - ) -> None: - async for chunk in source: - dest.write(chunk) - # Should we call close here? Or will that make the crt unable to read the last - # chunk? + def _is_sync_stream(self, body: core_aio_interfaces.StreamingBlob): + read = getattr(body, "read") + return read is not None and not asyncio.iscoroutinefunction(read)