33# pyright: reportMissingTypeStubs=false,reportUnknownMemberType=false
44# flake8: noqa: F811
55import asyncio
6+ from collections import deque
67from collections .abc import AsyncGenerator , AsyncIterable , Awaitable
78from concurrent .futures import Future
89from copy import deepcopy
9- from io import BytesIO
10+ from io import BytesIO , BufferedIOBase
1011from threading import Lock
1112from typing import TYPE_CHECKING , Any
1213
1718 from awscrt import http as crt_http
1819 from awscrt import io as crt_io
1920
21+ # Both of these are types that essentially are "castable to bytes/memoryview"
22+ # Unfortunately they're not exposed anywhere so we have to import them from
23+ # _typeshed.
24+ from _typeshed import WriteableBuffer , ReadableBuffer
25+
2026try :
2127 from awscrt import http as crt_http
2228 from awscrt import io as crt_io
@@ -304,7 +310,7 @@ async def _marshal_request(
304310 # If the body is async, or potentially very large, start up a task to read
305311 # it into the BytesIO object that CRT needs. By using asyncio.create_task
306312 # we'll start the coroutine without having to explicitly await it.
307- crt_body = BytesIO ()
313+ crt_body = BufferableByteStream ()
308314 if not isinstance (body , AsyncIterable ):
309315 # If the body isn't already an async iterable, wrap it in one. Objects
310316 # with read methods will be read in chunks so as not to exhaust memory.
@@ -327,15 +333,92 @@ async def _marshal_request(
327333 return crt_request
328334
329335 async def _consume_body_async (
330- self , source : AsyncIterable [bytes ], dest : BytesIO
336+ self , source : AsyncIterable [bytes ], dest : "BufferableByteStream"
331337 ) -> None :
332338 async for chunk in source :
333339 dest .write (chunk )
334- # Should we call close here? Or will that make the crt unable to read the last
335- # chunk?
340+ dest .end_stream ()
336341
337342 def __deepcopy__ (self , memo : Any ) -> "AWSCRTHTTPClient" :
338343 return AWSCRTHTTPClient (
339344 eventloop = self ._eventloop ,
340345 client_config = deepcopy (self ._config ),
341346 )
347+
348+
349+ # This is adapted from the transcribe streaming sdk
350+ class BufferableByteStream (BufferedIOBase ):
351+ """A non-blocking bytes buffer."""
352+
353+ def __init__ (self ) -> None :
354+ # We're always manipulating the front and back of the buffer, so a deque
355+ # will be much more efficient than a list.
356+ self ._chunks : deque [bytes ] = deque ()
357+ self ._closed = False
358+ self ._done = False
359+
360+ def read (self , size : int | None = - 1 ) -> bytes :
361+ if self ._closed :
362+ return b""
363+
364+ if len (self ._chunks ) == 0 :
365+ # When the CRT recieves this, it'll try again later.
366+ raise BlockingIOError ("read" )
367+
368+ # We could compile all the chunks here instead of just returning
369+ # the one, BUT the CRT will keep calling read until empty bytes
370+ # are returned. So it's actually better to just return one chunk
371+ # since combining them would have some potentially bad memory
372+ # usage issues.
373+ result = self ._chunks .popleft ()
374+ if size is not None and size > 0 :
375+ remainder = result [size :]
376+ result = result [:size ]
377+ if remainder :
378+ self ._chunks .appendleft (remainder )
379+
380+ if self ._done and len (self ._chunks ) == 0 :
381+ self .close ()
382+
383+ return result
384+
385+ def read1 (self , size : int = - 1 ) -> bytes :
386+ return self .read (size )
387+
388+ def readinto (self , buffer : "WriteableBuffer" ) -> int :
389+ if not isinstance (buffer , memoryview ):
390+ buffer = memoryview (buffer ).cast ("B" )
391+
392+ data = self .read (len (buffer )) # type: ignore
393+ n = len (data )
394+ buffer [:n ] = data
395+ return n
396+
397+ def write (self , buffer : "ReadableBuffer" ) -> int :
398+ if not isinstance (buffer , bytes ):
399+ raise ValueError (
400+ f"Unexpected value written to BufferableByteStream. "
401+ f"Only bytes are support but { type (buffer )} was provided."
402+ )
403+
404+ if self ._closed :
405+ raise IOError ("Stream is completed and doesn't support further writes." )
406+
407+ if buffer :
408+ self ._chunks .append (buffer )
409+ return len (buffer )
410+
411+ @property
412+ def closed (self ) -> bool :
413+ return self ._closed
414+
415+ def close (self ) -> None :
416+ self ._closed = True
417+ self ._done = True
418+
419+ # Clear out the remaining chunks so that they don't sit around in memory.
420+ self ._chunks .clear ()
421+
422+ def end_stream (self ) -> None :
423+ """End the stream, letting any remaining chunks be read before it is closed."""
424+ self ._done = True
0 commit comments