|
1 | 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
| 3 | +import asyncio |
3 | 4 | from io import BytesIO |
4 | 5 | from typing import Self |
5 | 6 |
|
6 | 7 | import pytest |
7 | 8 |
|
8 | | -from smithy_core.aio.types import AsyncBytesReader, SeekableAsyncBytesReader |
| 9 | +from smithy_core.aio.types import ( |
| 10 | + AsyncBytesProvider, |
| 11 | + AsyncBytesReader, |
| 12 | + SeekableAsyncBytesReader, |
| 13 | +) |
| 14 | +from smithy_core.exceptions import SmithyException |
9 | 15 |
|
10 | 16 |
|
11 | 17 | class _AsyncIteratorWrapper: |
@@ -270,3 +276,166 @@ async def test_seekable_iter_chunks() -> None: |
270 | 276 |
|
271 | 277 | assert reader.tell() == 3 |
272 | 278 | assert result == b"foo" |
| 279 | + |
| 280 | + |
| 281 | +async def test_provider_requires_positive_max_chunks() -> None: |
| 282 | + with pytest.raises(ValueError): |
| 283 | + AsyncBytesProvider(max_buffered_chunks=-1) |
| 284 | + |
| 285 | + |
async def drain_provider(provider: AsyncBytesProvider, dest: list[bytes]) -> None:
    """Append every chunk yielded by ``provider`` to ``dest`` as it arrives.

    Chunks are appended eagerly (not batched at the end) so that code
    observing ``dest`` concurrently can see partial progress while the
    provider is still open.
    """
    async for data in provider:
        dest.append(data)
| 289 | + |
| 290 | + |
async def test_provider_reads_initial_data() -> None:
    """Data passed to the constructor is readable before any explicit write."""
    # Pass the seed data positionally, consistent with the other tests in
    # this file, rather than via the constructor's (misspelled) keyword
    # name `intial_data`.
    provider = AsyncBytesProvider(b"foo")
    result: list[bytes] = []

    # Start the read task in the background.
    read_task = asyncio.create_task(drain_provider(provider, result))

    # Wait for the buffer to drain. At that point all the data should
    # have been read, but the read task won't actually be complete yet
    # because it's still waiting on future data.
    await provider.flush()
    assert result == [b"foo"]
    assert not read_task.done()

    # Now actually close the provider, which will let the read task
    # complete.
    await provider.close()
    await read_task

    # The result should not have changed.
    assert result == [b"foo"]
| 312 | + |
| 313 | + |
async def test_provider_reads_written_data() -> None:
    """Chunks written after the reader starts are delivered to it."""
    provider = AsyncBytesProvider()
    collected: list[bytes] = []

    # Kick off the background reader, then push a single chunk through.
    reader = asyncio.create_task(drain_provider(provider, collected))
    await provider.write(b"foo")

    # Once the buffer drains, the chunk must have been delivered; the
    # reader itself stays alive, though, waiting for more data until the
    # provider is closed.
    await provider.flush()
    assert collected == [b"foo"]
    assert not reader.done()

    # Closing the provider ends the iteration so the reader can finish.
    await provider.close()
    await reader

    # Nothing further should have arrived.
    assert collected == [b"foo"]
| 336 | + |
| 337 | + |
async def test_close_stops_writes() -> None:
    """Writing to a provider that has already been closed raises an error."""
    provider = AsyncBytesProvider()
    await provider.close()
    with pytest.raises(SmithyException):
        await provider.write(b"foo")
| 343 | + |
| 344 | + |
async def test_close_deletes_buffered_data() -> None:
    """Closing without a flush discards any chunks still in the buffer."""
    provider = AsyncBytesProvider(b"foo")
    await provider.close()

    collected: list[bytes] = []
    await drain_provider(provider, collected)
    assert collected == []

    # Reading produced nothing, which is what we want. Additionally poke
    # at the internals to confirm the buffer itself was emptied and no
    # stale data is hanging around.
    assert provider._data == []  # type: ignore
| 356 | + |
| 357 | + |
async def test_only_max_chunks_buffered() -> None:
    """A full buffer blocks writers until a reader drains it."""
    # A max buffer of one chunk, immediately occupied by the initial data.
    provider = AsyncBytesProvider(b"foo", max_buffered_chunks=1)

    # Queue up a write. create_task schedules it right away, but it can't
    # begin executing until this coroutine yields control.
    pending_write = asyncio.create_task(provider.write(b"bar"))

    # Yield so the write gets its turn. It must still be pending because
    # the buffer is full; a tenth of a second is far more time than a
    # normal write would need.
    await asyncio.sleep(0.1)
    assert not pending_write.done()

    # Start draining in the background. That frees buffer space and so
    # unblocks the writer.
    collected: list[bytes] = []
    reader = asyncio.create_task(drain_provider(provider, collected))

    # The reader keeps running until the provider closes, but the write
    # should be able to finish now.
    await pending_write

    # Writer and reader don't necessarily finish in lockstep, so wait for
    # the buffer to empty before checking what was read.
    await provider.flush()
    assert collected == [b"foo", b"bar"]

    # Close the provider so the reader's iteration terminates.
    await provider.close()
    await reader
| 391 | + |
| 392 | + |
async def test_close_stops_queued_writes() -> None:
    """A write blocked on a full buffer fails when the provider closes."""
    # A max buffer of one chunk, immediately occupied by the initial data.
    provider = AsyncBytesProvider(b"foo", max_buffered_chunks=1)

    # Queue a write that can't proceed until the buffer frees up.
    pending_write = asyncio.create_task(provider.write(b"bar"))

    # Closing now should cause that blocked write to error out.
    await provider.close()

    with pytest.raises(SmithyException):
        await pending_write
| 407 | + |
| 408 | + |
async def test_close_with_flush() -> None:
    """Closing with ``flush=True`` drains buffered data but rejects blocked writes."""
    # A max buffer of one chunk, immediately occupied by the initial data.
    provider = AsyncBytesProvider(b"foo", max_buffered_chunks=1)

    # Queue a write that can't proceed until the buffer frees up.
    pending_write = asyncio.create_task(provider.write(b"bar"))

    # Flush and close. A reader can still consume the already buffered
    # data, but the blocked write should fail.
    closing = asyncio.create_task(provider.close(flush=True))

    # Whether a queued write slips in ahead of the close is a matter of
    # task ordering, so yield here to let both the write task and the
    # close task check their conditions and settle the necessary state.
    await asyncio.sleep(0.1)

    # Now drain. This can be awaited directly because the close task
    # completes in the background, which then stops the iteration.
    collected: list[bytes] = []
    await drain_provider(provider, collected)

    # Make sure the close task itself has finished.
    await closing

    # The close blocked the write, so only the initial data was readable,
    # and the write task raises because the provider closed before it
    # could deliver its chunk.
    assert collected == [b"foo"]
    with pytest.raises(SmithyException):
        await pending_write
0 commit comments