@@ -342,9 +342,9 @@ async def test_close_stops_writes() -> None:
342342 await provider .write (b"foo" )
343343
344344
345- async def test_close_deletes_buffered_data () -> None :
345+ async def test_close_without_flush_deletes_buffered_data () -> None :
346346 provider = AsyncBytesProvider (b"foo" )
347- await provider .close ()
347+ await provider .close (flush = False )
348348 result : list [bytes ] = []
349349 await drain_provider (provider , result )
350350 assert result == []
@@ -400,7 +400,7 @@ async def test_close_stops_queued_writes() -> None:
400400 write_task = asyncio .create_task (provider .write (b"bar" ))
401401
402402 # Now close the provider. The write task will raise an error.
403- await provider .close ()
403+ await provider .close (flush = False )
404404
405405 with pytest .raises (SmithyException ):
406406 await write_task
@@ -439,3 +439,29 @@ async def test_close_with_flush() -> None:
439439 assert result == [b"foo" ]
440440 with pytest .raises (SmithyException ):
441441 await write_task
442+
443+
444+ async def test_aexit_flushes () -> None :
445+ # Initialize a provider, keeping track of it in the top scope just to make
446+ # sure it doesn't get GC'd
447+ provider = AsyncBytesProvider ()
448+
449+ # Use the provider in a context manager. When this exits, it should flush
450+ # and close the provider.
451+ async with provider :
452+
453+ # Write some data to the provider.
454+ await provider .write (b"foo" )
455+
456+ # Start the task to read data from the provider and exit. Explictly do
457+ # not await it here because the exit function should pass priority while
458+ # it waits for the queue to drain.
459+ result : list [bytes ] = []
460+ drain_task = asyncio .create_task (drain_provider (provider , result ))
461+
462+ # The queue should have been read by this point.
463+ assert result == [b"foo" ]
464+
465+ # The draining task should be able to complete without errors. When next it
466+ # tries to get a chunk, the provider's iterator will exit.
467+ await drain_task
0 commit comments