|
1 | 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
| 3 | +import asyncio |
3 | 4 | from copy import deepcopy |
4 | 5 | from io import BytesIO |
| 6 | +from unittest.mock import Mock |
| 7 | +from concurrent.futures import Future as ConcurrentFuture |
5 | 8 |
|
6 | 9 | import pytest |
| 10 | +from awscrt.http import HttpClientStream # type: ignore |
7 | 11 |
|
8 | 12 | from smithy_core import URI |
9 | 13 | from smithy_http import Fields |
10 | 14 | from smithy_http.aio import HTTPRequest |
11 | | -from smithy_http.aio.crt import AWSCRTHTTPClient, BufferableByteStream |
| 15 | +from smithy_http.aio.crt import AWSCRTHTTPClient, BufferableByteStream, CRTResponseBody |
12 | 16 |
|
13 | 17 |
|
14 | 18 | def test_deepcopy_client() -> None: |
@@ -136,3 +140,77 @@ def test_end_stream() -> None: |
136 | 140 | assert not stream.closed |
137 | 141 | assert stream.read() == b"foo" |
138 | 142 | assert stream.closed |
| 143 | + |
| 144 | + |
| 145 | +async def test_response_body_completed_stream() -> None: |
| 146 | + completion_future = ConcurrentFuture[int]() |
| 147 | + mock_stream = Mock(spec=HttpClientStream) |
| 148 | + mock_stream.completion_future = completion_future |
| 149 | + |
| 150 | + response_body = CRTResponseBody() |
| 151 | + response_body.set_stream(mock_stream) |
| 152 | + completion_future.set_result(200) |
| 153 | + |
| 154 | + assert await response_body.next() == b"" |
| 155 | + |
| 156 | + |
| 157 | +async def test_response_body_empty_stream() -> None: |
| 158 | + completion_future = ConcurrentFuture[int]() |
| 159 | + mock_stream = Mock(spec=HttpClientStream) |
| 160 | + mock_stream.completion_future = completion_future |
| 161 | + |
| 162 | + response_body = CRTResponseBody() |
| 163 | + response_body.set_stream(mock_stream) |
| 164 | + |
| 165 | + read_task = asyncio.create_task(response_body.next()) |
| 166 | + |
| 167 | + # Sleep briefly so the read task gets priority. It should |
| 168 | + # add a chunk future and then await it. |
| 169 | + await asyncio.sleep(0.01) |
| 170 | + |
| 171 | + assert len(response_body._chunk_futures) == 1 # type: ignore |
| 172 | + response_body.on_body(b"foo") |
| 173 | + assert await read_task == b"foo" |
| 174 | + |
| 175 | + |
| 176 | +async def test_response_body_stream_completion_clears_buffer() -> None: |
| 177 | + completion_future = ConcurrentFuture[int]() |
| 178 | + mock_stream = Mock(spec=HttpClientStream) |
| 179 | + mock_stream.completion_future = completion_future |
| 180 | + |
| 181 | + response_body = CRTResponseBody() |
| 182 | + response_body.set_stream(mock_stream) |
| 183 | + |
| 184 | + read_tasks = ( |
| 185 | + asyncio.create_task(response_body.next()), |
| 186 | + asyncio.create_task(response_body.next()), |
| 187 | + asyncio.create_task(response_body.next()), |
| 188 | + asyncio.create_task(response_body.next()), |
| 189 | + ) |
| 190 | + |
| 191 | + # Sleep briefly so the read tasks gets priority. It should |
| 192 | + # add a chunk future and then await it. |
| 193 | + await asyncio.sleep(0.01) |
| 194 | + |
| 195 | + assert len(response_body._chunk_futures) == 4 # type: ignore |
| 196 | + completion_future.set_result(200) |
| 197 | + await asyncio.sleep(0.01) |
| 198 | + |
| 199 | + # Tasks should have been drained |
| 200 | + assert len(response_body._chunk_futures) == 0 # type: ignore |
| 201 | + |
| 202 | + # Tasks should still be awaited, and should all return empty |
| 203 | + results = asyncio.gather(*read_tasks) |
| 204 | + assert results.result() == [b"", b"", b"", b""] |
| 205 | + |
| 206 | + |
| 207 | +async def test_response_body_non_empty_stream() -> None: |
| 208 | + completion_future = ConcurrentFuture[int]() |
| 209 | + mock_stream = Mock(spec=HttpClientStream) |
| 210 | + mock_stream.completion_future = completion_future |
| 211 | + |
| 212 | + response_body = CRTResponseBody() |
| 213 | + response_body.set_stream(mock_stream) |
| 214 | + response_body.on_body(b"foo") |
| 215 | + |
| 216 | + assert await response_body.next() == b"foo" |
0 commit comments