From ebc7dccad69d1f0db44be961b290afd04c451619 Mon Sep 17 00:00:00 2001
From: JordonPhillips
Date: Fri, 7 Mar 2025 14:07:45 +0100
Subject: [PATCH] Add non-blocking buffer to interface with CRT

---
 .../smithy-http/src/smithy_http/aio/crt.py    | 116 +++++++++++++++++-
 .../smithy-http/tests/unit/aio/test_crt.py    | 115 ++++++++++++++++-
 2 files changed, 225 insertions(+), 6 deletions(-)

diff --git a/packages/smithy-http/src/smithy_http/aio/crt.py b/packages/smithy-http/src/smithy_http/aio/crt.py
index 2dec88a06..e043088ba 100644
--- a/packages/smithy-http/src/smithy_http/aio/crt.py
+++ b/packages/smithy-http/src/smithy_http/aio/crt.py
@@ -3,10 +3,11 @@
 # pyright: reportMissingTypeStubs=false,reportUnknownMemberType=false
 # flake8: noqa: F811
 import asyncio
+from collections import deque
 from collections.abc import AsyncGenerator, AsyncIterable, Awaitable
 from concurrent.futures import Future
 from copy import deepcopy
-from io import BytesIO
+from io import BytesIO, BufferedIOBase
 from threading import Lock
 from typing import TYPE_CHECKING, Any
 
@@ -17,6 +18,11 @@
     from awscrt import http as crt_http
     from awscrt import io as crt_io
 
+    # Both of these are types that essentially are "castable to bytes/memoryview"
+    # Unfortunately they're not exposed anywhere so we have to import them from
+    # _typeshed.
+    from _typeshed import WriteableBuffer, ReadableBuffer
+
 try:
     from awscrt import http as crt_http
     from awscrt import io as crt_io
@@ -304,7 +310,7 @@ async def _marshal_request(
             # If the body is async, or potentially very large, start up a task to read
             # it into the BytesIO object that CRT needs. By using asyncio.create_task
             # we'll start the coroutine without having to explicitly await it.
-            crt_body = BytesIO()
+            crt_body = BufferableByteStream()
             if not isinstance(body, AsyncIterable):
                 # If the body isn't already an async iterable, wrap it in one. Objects
                 # with read methods will be read in chunks so as not to exhaust memory.
@@ -327,15 +333,115 @@ async def _marshal_request(
         return crt_request
 
     async def _consume_body_async(
-        self, source: AsyncIterable[bytes], dest: BytesIO
+        self, source: AsyncIterable[bytes], dest: "BufferableByteStream"
     ) -> None:
         async for chunk in source:
             dest.write(chunk)
-        # Should we call close here? Or will that make the crt unable to read the last
-        # chunk?
+        dest.end_stream()
 
     def __deepcopy__(self, memo: Any) -> "AWSCRTHTTPClient":
         return AWSCRTHTTPClient(
             eventloop=self._eventloop,
             client_config=deepcopy(self._config),
         )
+
+
+# This is adapted from the transcribe streaming sdk
+class BufferableByteStream(BufferedIOBase):
+    """A non-blocking bytes buffer.
+
+    Chunks are appended with ``write`` and handed back one at a time by the
+    read methods. Reading while the buffer is empty raises ``BlockingIOError``;
+    once ``end_stream`` has been called and the buffer is drained, reads return
+    ``b""`` and the stream is closed.
+    """
+
+    def __init__(self) -> None:
+        # We're always manipulating the front and back of the buffer, so a deque
+        # will be much more efficient than a list.
+        self._chunks: deque[bytes] = deque()
+        self._closed = False
+        self._done = False
+
+    def read(self, size: int | None = -1) -> bytes:
+        """Read at most one buffered chunk.
+
+        :param size: The maximum number of bytes to return. ``None`` or a negative
+            value returns the whole next chunk.
+        :returns: The bytes read, or ``b""`` if the stream is closed or ``size`` is 0.
+        :raises BlockingIOError: If the buffer is empty and the stream hasn't ended.
+        """
+        if self._closed:
+            return b""
+
+        if len(self._chunks) == 0:
+            if self._done:
+                # end_stream() was called after the buffer was drained (or before
+                # anything was written), so report EOF instead of blocking forever.
+                self.close()
+                return b""
+            # When the CRT receives this, it'll try again later.
+            raise BlockingIOError("read")
+
+        # We could compile all the chunks here instead of just returning
+        # the one, BUT the CRT will keep calling read until empty bytes
+        # are returned. So it's actually better to just return one chunk
+        # since combining them would have some potentially bad memory
+        # usage issues.
+        result = self._chunks.popleft()
+        # A size of 0 must leave the chunk buffered rather than return all of it.
+        if size is not None and size >= 0:
+            remainder = result[size:]
+            result = result[:size]
+            if remainder:
+                self._chunks.appendleft(remainder)
+
+        if self._done and len(self._chunks) == 0:
+            self.close()
+
+        return result
+
+    def read1(self, size: int = -1) -> bytes:
+        """Read at most one buffered chunk; equivalent to ``read``."""
+        return self.read(size)
+
+    def readinto(self, buffer: "WriteableBuffer") -> int:
+        """Copy at most one buffered chunk into ``buffer``, returning the byte count."""
+        if not isinstance(buffer, memoryview):
+            buffer = memoryview(buffer).cast("B")
+
+        data = self.read(len(buffer))  # type: ignore
+        n = len(data)
+        buffer[:n] = data
+        return n
+
+    def write(self, buffer: "ReadableBuffer") -> int:
+        """Append ``buffer`` to the stream; only ``bytes`` are accepted."""
+        if not isinstance(buffer, bytes):
+            raise ValueError(
+                f"Unexpected value written to BufferableByteStream. "
+                f"Only bytes are support but {type(buffer)} was provided."
+            )
+
+        if self._closed:
+            raise IOError("Stream is completed and doesn't support further writes.")
+
+        if buffer:
+            self._chunks.append(buffer)
+        return len(buffer)
+
+    @property
+    def closed(self) -> bool:
+        return self._closed
+
+    def close(self) -> None:
+        """Close the stream immediately, discarding any unread chunks."""
+        self._closed = True
+        self._done = True
+
+        # Clear out the remaining chunks so that they don't sit around in memory.
+        self._chunks.clear()
+
+    def end_stream(self) -> None:
+        """End the stream, letting any remaining chunks be read before it is closed."""
+        self._done = True
diff --git a/packages/smithy-http/tests/unit/aio/test_crt.py b/packages/smithy-http/tests/unit/aio/test_crt.py
index 1afabc3d8..3ca5197cc 100644
--- a/packages/smithy-http/tests/unit/aio/test_crt.py
+++ b/packages/smithy-http/tests/unit/aio/test_crt.py
@@ -1,8 +1,121 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
 from copy import deepcopy
 
-from smithy_http.aio.crt import AWSCRTHTTPClient
+import pytest
+
+from smithy_http.aio.crt import AWSCRTHTTPClient, BufferableByteStream
 
 
 def test_deepcopy_client() -> None:
     client = AWSCRTHTTPClient()
     deepcopy(client)
+
+
+def test_stream_write() -> None:
+    stream = BufferableByteStream()
+    stream.write(b"foo")
+    assert stream.read() == b"foo"
+
+
+def test_stream_reads_individual_chunks() -> None:
+    stream = BufferableByteStream()
+    stream.write(b"foo")
+    stream.write(b"bar")
+    assert stream.read() == b"foo"
+    assert stream.read() == b"bar"
+
+
+def test_stream_empty_read() -> None:
+    stream = BufferableByteStream()
+    with pytest.raises(BlockingIOError):
+        stream.read()
+
+
+def test_stream_partial_chunk_read() -> None:
+    stream = BufferableByteStream()
+    stream.write(b"foobar")
+    assert stream.read(3) == b"foo"
+    assert stream.read() == b"bar"
+
+
+def test_stream_write_empty_bytes() -> None:
+    stream = BufferableByteStream()
+    stream.write(b"")
+    stream.write(b"foo")
+    stream.write(b"")
+    assert stream.read() == b"foo"
+
+
+def test_stream_write_non_bytes() -> None:
+    stream = BufferableByteStream()
+    with pytest.raises(ValueError):
+        stream.write(memoryview(b"foo"))
+
+
+def test_closed_stream_write() -> None:
+    stream = BufferableByteStream()
+    stream.close()
+    with pytest.raises(IOError):
+        stream.write(b"foo")
+
+
+def test_closed_stream_read() -> None:
+    stream = BufferableByteStream()
+    stream.write(b"foo")
+    stream.close()
+    assert stream.read() == b""
+
+
+def test_stream_read1() -> None:
+    stream = BufferableByteStream()
+    stream.write(b"foo")
+    stream.write(b"bar")
+    assert stream.read1() == b"foo"
+    assert stream.read1() == b"bar"
+    with pytest.raises(BlockingIOError):
+        stream.read()
+
+
+def test_stream_readinto_memoryview() -> None:
+    buffer = memoryview(bytearray(b"   "))
+    stream = BufferableByteStream()
+    stream.write(b"foobar")
+    stream.readinto(buffer)
+    assert bytes(buffer) == b"foo"
+
+
+def test_stream_readinto_bytearray() -> None:
+    buffer = bytearray(b"   ")
+    stream = BufferableByteStream()
+    stream.write(b"foobar")
+    stream.readinto(buffer)
+    assert bytes(buffer) == b"foo"
+
+
+def test_end_stream() -> None:
+    stream = BufferableByteStream()
+    stream.write(b"foo")
+    stream.end_stream()
+
+    assert not stream.closed
+    assert stream.read() == b"foo"
+    assert stream.closed
+
+
+def test_end_stream_after_buffer_drained() -> None:
+    stream = BufferableByteStream()
+    stream.write(b"foo")
+    assert stream.read() == b"foo"
+    stream.end_stream()
+
+    assert not stream.closed
+    assert stream.read() == b""
+    assert stream.closed
+
+
+def test_stream_zero_size_read() -> None:
+    stream = BufferableByteStream()
+    stream.write(b"foo")
+    assert stream.read(0) == b""
+    assert stream.read() == b"foo"