Skip to content

Commit b7f51bf

Browse files
Add an asynchronous bytes provider
This adds a class that allows bytes to be exchanged asynchronously. It primarily serves the purpose of enabling event streaming. Event streams will create a provider under the hood that will be written to. That provider will then be passed as the transport request body, which will read it as an async bytes iterable.
1 parent 0fc3746 commit b7f51bf

File tree

2 files changed

+323
-2
lines changed

2 files changed

+323
-2
lines changed

python-packages/smithy-core/smithy_core/aio/types.py

Lines changed: 153 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3+
import asyncio
34
from asyncio import iscoroutinefunction
45
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
56
from io import BytesIO
6-
from typing import Self, cast
7+
from typing import Any, Self, cast
78

9+
from ..exceptions import SmithyException
810
from ..interfaces import BytesReader
911
from .interfaces import AsyncByteStream, StreamingBlob
1012

@@ -271,3 +273,153 @@ async def __anext__(self) -> bytes:
271273
if data:
272274
return data
273275
raise StopAsyncIteration
276+
277+
278+
class AsyncBytesProvider:
279+
"""A buffer that allows chunks of bytes to be exchanged asynchronously.
280+
281+
Bytes are written in chunks to an internal buffer, that is then drained via an async
282+
iterator.
283+
"""
284+
285+
def __init__(
286+
self, intial_data: bytes | None = None, max_buffered_chunks: int = 16
287+
) -> None:
288+
"""Initialize the AsyncBytesProvider.
289+
290+
:param initial_data: An initial chunk of bytes to make available.
291+
:param max_buffered_chunks: The maximum number of chunks of data to buffer.
292+
Calls to ``write`` will block until the number of chunks is less than this
293+
number. Default is 16.
294+
"""
295+
if intial_data is not None:
296+
self._data = [intial_data]
297+
else:
298+
self._data = []
299+
300+
if max_buffered_chunks < 1:
301+
raise ValueError(
302+
"The maximum number of buffered chunks must be greater than 0."
303+
)
304+
305+
self._closed = False
306+
self._closing = False
307+
self._flushing = False
308+
self._max_buffered_chunks = max_buffered_chunks
309+
310+
# Create a Condition to synchronize access to the data chunk pool.
311+
self._data_condition = asyncio.Condition()
312+
313+
async def write(self, data: bytes) -> None:
314+
if self._closed:
315+
raise SmithyException("Attempted to write to a closed provider.")
316+
317+
# Acquire a lock on the data buffer, releasing it automatically when the
318+
# block exits.
319+
async with self._data_condition:
320+
321+
# Wait for the number of chunks in the buffer to be less than the
322+
# specified maximum. This also releases the lock until the condition
323+
# is met.
324+
await self._data_condition.wait_for(self._can_write)
325+
326+
# The provider could have been closed while waiting to write, so an
327+
# additional check is done here for safety.
328+
if self._closed or self._closing:
329+
# Notify to allow other coroutines to check their conditions.
330+
self._data_condition.notify()
331+
raise SmithyException(
332+
"Attempted to write to a closed or closing provider."
333+
)
334+
335+
# Add a new chunk of data to the buffer and notify the next waiting
336+
# coroutine.
337+
self._data.append(data)
338+
self._data_condition.notify()
339+
340+
def _can_write(self) -> bool:
341+
return (
342+
self._closed
343+
or self._closing
344+
or (len(self._data) < self._max_buffered_chunks and not self._flushing)
345+
)
346+
347+
@property
348+
def closed(self) -> bool:
349+
"""Returns whether the provider is closed."""
350+
return self._closed
351+
352+
async def flush(self) -> None:
353+
"""Waits for all buffered data to be consumed."""
354+
if self._closed:
355+
return
356+
357+
# Acquire a lock on the data buffer, releasing it automatically when the
358+
# block exits.
359+
async with self._data_condition:
360+
# Block writes
361+
self._flushing = True
362+
363+
# Wait for the stream to be closed or for the data buffer to be empty.
364+
await self._data_condition.wait_for(lambda: len(self._data) == 0)
365+
366+
# Unblock writes
367+
self._flushing = False
368+
369+
async def close(self, flush: bool = False) -> None:
370+
"""Closes the provider.
371+
372+
Pending writing tasks queued after this will fail, so such tasks should be
373+
awaited before this. Write tasks queued before this may succeed, however.
374+
375+
:param flush: Whether to flush buffered data before closing. If false, all
376+
buffered data will be lost. Default is False.
377+
"""
378+
if self._closed:
379+
return
380+
381+
# Acquire a lock on the data buffer, releasing it automatically when the
382+
# block exits. Notably this will not wait on a condition to move forward.
383+
async with self._data_condition:
384+
self._closing = True
385+
if flush:
386+
await self._data_condition.wait_for(lambda: len(self._data) == 0)
387+
else:
388+
# Clear out any pending data, freeing up memory.
389+
self._data.clear()
390+
391+
self._closed = True
392+
self._closing = False
393+
394+
# Notify all waiting coroutines that the provider has closed.
395+
self._data_condition.notify_all()
396+
397+
def __aiter__(self) -> Self:
398+
return self
399+
400+
async def __anext__(self) -> bytes:
401+
# Acquire a lock on the data buffer, releasing it automatically when the
402+
# block exits.
403+
async with self._data_condition:
404+
405+
# Wait for the stream to be closed or for the data buffer to be non-empty.
406+
# This also releases the lock until the condition is met.
407+
await self._data_condition.wait_for(
408+
lambda: self._closed or len(self._data) > 0
409+
)
410+
411+
# If the provider is closed, end the iteration.
412+
if self._closed:
413+
raise StopAsyncIteration
414+
415+
# Pop the next chunk of data from the buffer, then notify any waiting
416+
# coroutines, returning immediately after.
417+
result = self._data.pop()
418+
self._data_condition.notify()
419+
return result
420+
421+
async def __aenter__(self) -> Self:
422+
return self
423+
424+
async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
425+
await self.close()

python-packages/smithy-core/tests/unit/aio/test_types.py

Lines changed: 170 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3+
import asyncio
34
from io import BytesIO
45
from typing import Self
56

67
import pytest
78

8-
from smithy_core.aio.types import AsyncBytesReader, SeekableAsyncBytesReader
9+
from smithy_core.aio.types import (
10+
AsyncBytesProvider,
11+
AsyncBytesReader,
12+
SeekableAsyncBytesReader,
13+
)
14+
from smithy_core.exceptions import SmithyException
915

1016

1117
class _AsyncIteratorWrapper:
@@ -270,3 +276,166 @@ async def test_seekable_iter_chunks() -> None:
270276

271277
assert reader.tell() == 3
272278
assert result == b"foo"
279+
280+
281+
async def test_provider_requires_positive_max_chunks() -> None:
    """A non-positive ``max_buffered_chunks`` is rejected at construction."""
    # Zero is the boundary value, so it is checked alongside a negative size.
    for invalid_size in (0, -1):
        with pytest.raises(ValueError):
            AsyncBytesProvider(max_buffered_chunks=invalid_size)
284+
285+
286+
async def drain_provider(provider: AsyncBytesProvider, dest: list[bytes]) -> None:
    """Append every chunk yielded by ``provider`` to ``dest`` until it is exhausted.

    Chunks are appended one at a time as they arrive, so ``dest`` can be
    inspected while the drain is still in progress.
    """
    chunks = aiter(provider)
    while True:
        try:
            chunk = await anext(chunks)
        except StopAsyncIteration:
            return
        dest.append(chunk)
289+
290+
291+
async def test_provider_reads_initial_data() -> None:
    """Data handed to the constructor is delivered to a reader."""
    received: list[bytes] = []
    source = AsyncBytesProvider(intial_data=b"foo")

    # Drain in the background so this coroutine can keep driving the provider.
    reader = asyncio.create_task(drain_provider(source, received))

    # flush() returns once the buffer is empty. By then the reader has consumed
    # everything, but it has not finished because it is still waiting for
    # further chunks.
    await source.flush()
    assert received == [b"foo"]
    assert not reader.done()

    # Closing the provider ends the iteration, which lets the reader finish.
    await source.close()
    await reader

    # Closing must not have produced any additional chunks.
    assert received == [b"foo"]
312+
313+
314+
async def test_provider_reads_written_data() -> None:
    """Data passed to ``write`` is delivered to a reader."""
    received: list[bytes] = []
    source = AsyncBytesProvider()

    # Drain in the background, then hand the provider a chunk to deliver.
    reader = asyncio.create_task(drain_provider(source, received))
    await source.write(b"foo")

    # flush() returns once the buffer is empty. By then the reader has consumed
    # everything, but it has not finished because it is still waiting for
    # further chunks.
    await source.flush()
    assert received == [b"foo"]
    assert not reader.done()

    # Closing the provider ends the iteration, which lets the reader finish.
    await source.close()
    await reader

    # Closing must not have produced any additional chunks.
    assert received == [b"foo"]
336+
337+
338+
async def test_close_stops_writes() -> None:
    """Writing to a provider that is already closed raises."""
    closed_provider = AsyncBytesProvider()
    await closed_provider.close()

    with pytest.raises(SmithyException):
        await closed_provider.write(b"foo")
343+
344+
345+
async def test_close_deletes_buffered_data() -> None:
    """Closing without a flush discards whatever was still buffered."""
    source = AsyncBytesProvider(b"foo")
    await source.close()

    drained: list[bytes] = []
    await drain_provider(source, drained)
    assert drained == []

    # Nothing could be read, which is the desired outcome. Additionally peek at
    # the internals to confirm the buffer itself was emptied and no data is
    # hanging around.
    assert source._data == []  # type: ignore
356+
357+
358+
async def test_only_max_chunks_buffered() -> None:
    """A write blocks while the buffer already holds the maximum number of chunks."""
    # A one-chunk buffer that starts out full, courtesy of the initial chunk.
    source = AsyncBytesProvider(b"foo", max_buffered_chunks=1)

    # create_task enqueues the write right away, although it does not begin
    # executing until the loop gets around to it.
    pending_write = asyncio.create_task(source.write(b"bar"))

    # Step aside so the write can run. It should stay pending because it has to
    # wait for the buffer to drain. A tenth of a second is far more than it
    # would need to complete under normal circumstances.
    await asyncio.sleep(0.1)
    assert not pending_write.done()

    # Start draining in the background, which frees space in the buffer and
    # unblocks the write.
    received: list[bytes] = []
    reader = asyncio.create_task(drain_provider(source, received))

    # The reader keeps running until the provider is closed, but the write
    # should now be able to finish.
    await pending_write

    # The write finishing does not mean its chunk has been read yet, so wait
    # for the buffer to empty before checking what was received.
    await source.flush()
    assert received == [b"foo", b"bar"]

    # Close the provider so the reader can end, then wait for it.
    await source.close()
    await reader
391+
392+
393+
async def test_close_stops_queued_writes() -> None:
    """A write that is blocked on a full buffer fails when the provider closes."""
    # Initialize the provider with a max buffer of one and immediately have it
    # filled with an initial chunk.
    provider = AsyncBytesProvider(b"foo", max_buffered_chunks=1)

    # Schedule a write task. It can't complete until the buffer is free.
    write_task = asyncio.create_task(provider.write(b"bar"))

    # create_task only schedules the coroutine. Yield to the event loop once so
    # the write actually starts and ends up waiting for buffer space. Without
    # this, close() can finish before the write ever runs, and the test would
    # only cover writing to an already-closed provider.
    await asyncio.sleep(0)
    assert not write_task.done()

    # Now close the provider. The queued write task will raise an error.
    await provider.close()

    with pytest.raises(SmithyException):
        await write_task
407+
408+
409+
async def test_close_with_flush() -> None:
    """Closing with ``flush=True`` delivers buffered data but rejects queued writes."""
    # A one-chunk buffer that starts out full, courtesy of the initial chunk.
    source = AsyncBytesProvider(b"foo", max_buffered_chunks=1)

    # Enqueue a write. It cannot complete until the buffer has room.
    pending_write = asyncio.create_task(source.write(b"bar"))

    # Enqueue a flushing close. A reader will still be able to read the already
    # buffered data, but the queued write is expected to fail.
    closer = asyncio.create_task(source.close(flush=True))

    # Whether a write fails depends on timing: one that is queued ahead of the
    # close may still get through. Step aside here so both the write and the
    # close get a chance to check their conditions and set the state they need.
    await asyncio.sleep(0.1)

    # Drain the provider. This can be awaited directly because the close
    # completes in the background and then ends the iteration.
    received: list[bytes] = []
    await drain_provider(source, received)

    # Make sure the close has finished.
    await closer

    # The close kept the write from going through, so only the initial chunk
    # was read, and the write raises because the provider closed before it
    # could add its data.
    assert received == [b"foo"]
    with pytest.raises(SmithyException):
        await pending_write

0 commit comments

Comments
 (0)