Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 88 additions & 5 deletions packages/smithy-http/src/smithy_http/aio/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
# pyright: reportMissingTypeStubs=false,reportUnknownMemberType=false
# flake8: noqa: F811
import asyncio
from collections import deque
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable
from concurrent.futures import Future
from copy import deepcopy
from io import BytesIO
from io import BytesIO, BufferedIOBase
from threading import Lock
from typing import TYPE_CHECKING, Any

Expand All @@ -17,6 +18,11 @@
from awscrt import http as crt_http
from awscrt import io as crt_io

# Both of these are types that essentially are "castable to bytes/memoryview"
# Unfortunately they're not exposed anywhere so we have to import them from
# _typeshed.
from _typeshed import WriteableBuffer, ReadableBuffer

try:
from awscrt import http as crt_http
from awscrt import io as crt_io
Expand Down Expand Up @@ -304,7 +310,7 @@ async def _marshal_request(
# If the body is async, or potentially very large, start up a task to read
# it into the BytesIO object that CRT needs. By using asyncio.create_task
# we'll start the coroutine without having to explicitly await it.
crt_body = BytesIO()
crt_body = BufferableByteStream()
if not isinstance(body, AsyncIterable):
# If the body isn't already an async iterable, wrap it in one. Objects
# with read methods will be read in chunks so as not to exhaust memory.
Expand All @@ -327,15 +333,92 @@ async def _marshal_request(
return crt_request

async def _consume_body_async(
self, source: AsyncIterable[bytes], dest: BytesIO
self, source: AsyncIterable[bytes], dest: "BufferableByteStream"
) -> None:
async for chunk in source:
dest.write(chunk)
# Should we call close here? Or will that make the crt unable to read the last
# chunk?
dest.end_stream()

def __deepcopy__(self, memo: Any) -> "AWSCRTHTTPClient":
return AWSCRTHTTPClient(
eventloop=self._eventloop,
client_config=deepcopy(self._config),
)


# This is adapted from the transcribe streaming sdk
class BufferableByteStream(BufferedIOBase):
"""A non-blocking bytes buffer."""

def __init__(self) -> None:
# We're always manipulating the front and back of the buffer, so a deque
# will be much more efficient than a list.
self._chunks: deque[bytes] = deque()
self._closed = False
self._done = False

def read(self, size: int | None = -1) -> bytes:
if self._closed:
return b""

if len(self._chunks) == 0:
# When the CRT recieves this, it'll try again later.
raise BlockingIOError("read")

# We could compile all the chunks here instead of just returning
# the one, BUT the CRT will keep calling read until empty bytes
# are returned. So it's actually better to just return one chunk
# since combining them would have some potentially bad memory
# usage issues.
result = self._chunks.popleft()
if size is not None and size > 0:
remainder = result[size:]
result = result[:size]
if remainder:
self._chunks.appendleft(remainder)

if self._done and len(self._chunks) == 0:
self.close()

return result

def read1(self, size: int = -1) -> bytes:
return self.read(size)

def readinto(self, buffer: "WriteableBuffer") -> int:
if not isinstance(buffer, memoryview):
buffer = memoryview(buffer).cast("B")

data = self.read(len(buffer)) # type: ignore
n = len(data)
buffer[:n] = data
return n

def write(self, buffer: "ReadableBuffer") -> int:
if not isinstance(buffer, bytes):
raise ValueError(
f"Unexpected value written to BufferableByteStream. "
f"Only bytes are support but {type(buffer)} was provided."
)

if self._closed:
raise IOError("Stream is completed and doesn't support further writes.")

if buffer:
self._chunks.append(buffer)
return len(buffer)

@property
def closed(self) -> bool:
return self._closed

def close(self) -> None:
self._closed = True
self._done = True

# Clear out the remaining chunks so that they don't sit around in memory.
self._chunks.clear()

def end_stream(self) -> None:
"""End the stream, letting any remaining chunks be read before it is closed."""
self._done = True
97 changes: 96 additions & 1 deletion packages/smithy-http/tests/unit/aio/test_crt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,103 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from copy import deepcopy

from smithy_http.aio.crt import AWSCRTHTTPClient
import pytest

from smithy_http.aio.crt import AWSCRTHTTPClient, BufferableByteStream


def test_deepcopy_client() -> None:
client = AWSCRTHTTPClient()
deepcopy(client)


def test_stream_write() -> None:
stream = BufferableByteStream()
stream.write(b"foo")
assert stream.read() == b"foo"


def test_stream_reads_individual_chunks() -> None:
stream = BufferableByteStream()
stream.write(b"foo")
stream.write(b"bar")
assert stream.read() == b"foo"
assert stream.read() == b"bar"


def test_stream_empty_read() -> None:
stream = BufferableByteStream()
with pytest.raises(BlockingIOError):
stream.read()


def test_stream_partial_chunk_read() -> None:
stream = BufferableByteStream()
stream.write(b"foobar")
assert stream.read(3) == b"foo"
assert stream.read() == b"bar"


def test_stream_write_empty_bytes() -> None:
stream = BufferableByteStream()
stream.write(b"")
stream.write(b"foo")
stream.write(b"")
assert stream.read() == b"foo"


def test_stream_write_non_bytes() -> None:
stream = BufferableByteStream()
with pytest.raises(ValueError):
stream.write(memoryview(b"foo"))


def test_closed_stream_write() -> None:
stream = BufferableByteStream()
stream.close()
with pytest.raises(IOError):
stream.write(b"foo")


def test_closed_stream_read() -> None:
stream = BufferableByteStream()
stream.write(b"foo")
stream.close()
assert stream.read() == b""


def test_stream_read1() -> None:
stream = BufferableByteStream()
stream.write(b"foo")
stream.write(b"bar")
assert stream.read1() == b"foo"
assert stream.read1() == b"bar"
with pytest.raises(BlockingIOError):
stream.read()


def test_stream_readinto_memoryview() -> None:
buffer = memoryview(bytearray(b" "))
stream = BufferableByteStream()
stream.write(b"foobar")
stream.readinto(buffer)
assert bytes(buffer) == b"foo"


def test_stream_readinto_bytearray() -> None:
buffer = bytearray(b" ")
stream = BufferableByteStream()
stream.write(b"foobar")
stream.readinto(buffer)
assert bytes(buffer) == b"foo"


def test_end_stream() -> None:
stream = BufferableByteStream()
stream.write(b"foo")
stream.end_stream()

assert not stream.closed
assert stream.read() == b"foo"
assert stream.closed
Loading