|
25 | 25 | HAS_CRT = False # type: ignore |
26 | 26 |
|
27 | 27 | from smithy_core import interfaces as core_interfaces |
| 28 | +from smithy_core.aio.types import AsyncBytesReader |
28 | 29 | from smithy_core.exceptions import MissingDependencyException |
29 | 30 |
|
30 | 31 | from .. import Field, Fields |
@@ -187,6 +188,7 @@ def __init__( |
187 | 188 | self._tls_ctx = crt_io.ClientTlsContext(crt_io.TlsContextOptions()) |
188 | 189 | self._socket_options = crt_io.SocketOptions() |
189 | 190 | self._connections: ConnectionPoolDict = {} |
| 191 | + self._async_reads: set[asyncio.Task[Any]] = set() |
190 | 192 |
|
191 | 193 | async def send( |
192 | 194 | self, |
@@ -293,12 +295,42 @@ async def _marshal_request( |
293 | 295 |
|
294 | 296 | path = self._render_path(request.destination) |
295 | 297 | headers = crt_http.HttpHeaders(headers_list) |
296 | | - body = BytesIO(await request.consume_body_async()) |
| 298 | + |
| 299 | + body = request.body |
| 300 | + if isinstance(body, bytes | bytearray): |
| 301 | + # If the body is already directly in memory, wrap in in a BytesIO to hand |
| 302 | + # off to CRT. |
| 303 | + crt_body = BytesIO(body) |
| 304 | + else: |
| 305 | + # If the body is async, or potentially very large, start up a task to read |
| 306 | + # it into the BytesIO object that CRT needs. By using asyncio.create_task |
| 307 | + # we'll start the coroutine without having to explicitly await it. |
| 308 | + crt_body = BytesIO() |
| 309 | + if not isinstance(body, AsyncIterable): |
| 310 | + # If the body isn't already an async iterable, wrap it in one. Objects |
| 311 | + # with read methods will be read in chunks so as not to exhaust memory. |
| 312 | + body = AsyncBytesReader(body) |
| 313 | + |
| 314 | + # Start the read task int the background. |
| 315 | + read_task = asyncio.create_task(self._consume_body_async(body, crt_body)) |
| 316 | + |
| 317 | + # Keep track of the read task so that it doesn't get garbage colllected, |
| 318 | + # and stop tracking it once it's done. |
| 319 | + self._async_reads.add(read_task) |
| 320 | + read_task.add_done_callback(self._async_reads.discard) |
297 | 321 |
|
298 | 322 | crt_request = crt_http.HttpRequest( |
299 | 323 | method=request.method, |
300 | 324 | path=path, |
301 | 325 | headers=headers, |
302 | | - body_stream=body, |
| 326 | + body_stream=crt_body, |
303 | 327 | ) |
304 | 328 | return crt_request |
| 329 | + |
| 330 | + async def _consume_body_async( |
| 331 | + self, source: AsyncIterable[bytes], dest: BytesIO |
| 332 | + ) -> None: |
| 333 | + async for chunk in source: |
| 334 | + dest.write(chunk) |
| 335 | + # Should we call close here? Or will that make the crt unable to read the last |
| 336 | + # chunk? |
0 commit comments