Skip to content

Commit 41f0ca8

Browse files
Wrap completion future in crt response body
1 parent fcb036e commit 41f0ca8

File tree

2 files changed

+87
-5
lines changed

2 files changed

+87
-5
lines changed

packages/smithy-http/src/smithy_http/aio/crt.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# pyright: reportMissingTypeStubs=false,reportUnknownMemberType=false
44
# flake8: noqa: F811
55
import asyncio
6+
from asyncio import Future as AsyncFuture
67
from concurrent.futures import Future as ConcurrentFuture
78
from collections import deque
89
from collections.abc import AsyncGenerator, AsyncIterable
@@ -105,6 +106,7 @@ def __repr__(self) -> str:
105106
class CRTResponseBody:
106107
def __init__(self) -> None:
107108
self._stream: "crt_http.HttpClientStream | None" = None
109+
self._completion_future: AsyncFuture[int] | None = None
108110
self._chunk_futures: deque[ConcurrentFuture[bytes]] = deque()
109111

110112
# deque is thread safe and the crt is only going to be writing
@@ -117,7 +119,9 @@ def set_stream(self, stream: "crt_http.HttpClientStream") -> None:
117119
if self._stream is not None:
118120
raise SmithyHTTPException("Stream already set on AWSCRTHTTPResponse object")
119121
self._stream = stream
120-
self._stream.completion_future.add_done_callback(self._on_complete)
122+
concurrent_future: ConcurrentFuture[int] = stream.completion_future
123+
self._completion_future = asyncio.wrap_future(concurrent_future)
124+
self._completion_future.add_done_callback(self._on_complete)
121125
self._stream.activate()
122126

123127
def on_body(self, chunk: bytes, **kwargs: Any) -> None: # pragma: crt-callback
@@ -129,21 +133,21 @@ def on_body(self, chunk: bytes, **kwargs: Any) -> None: # pragma: crt-callback
129133
self._received_chunks.append(chunk)
130134

131135
async def next(self) -> bytes:
132-
if self._stream is None:
136+
if self._completion_future is None:
133137
raise SmithyHTTPException("Stream not set")
134138

135139
# TODO: update backpressure window once CRT supports it
136140
if self._received_chunks:
137141
return self._received_chunks.popleft()
138-
elif self._stream.completion_future.done():
142+
elif self._completion_future.done():
139143
return b""
140144
else:
141145
future = ConcurrentFuture[bytes]()
142146
self._chunk_futures.append(future)
143147
return await asyncio.wrap_future(future)
144148

145149
def _on_complete(
146-
self, completion_future: ConcurrentFuture[int]
150+
self, completion_future: AsyncFuture[int]
147151
) -> None: # pragma: crt-callback
148152
for future in self._chunk_futures:
149153
future.set_result(b"")

packages/smithy-http/tests/unit/aio/test_crt.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3+
import asyncio
34
from copy import deepcopy
45
from io import BytesIO
6+
from unittest.mock import Mock
7+
from concurrent.futures import Future as ConcurrentFuture
58

69
import pytest
10+
from awscrt.http import HttpClientStream # type: ignore
711

812
from smithy_core import URI
913
from smithy_http import Fields
1014
from smithy_http.aio import HTTPRequest
11-
from smithy_http.aio.crt import AWSCRTHTTPClient, BufferableByteStream
15+
from smithy_http.aio.crt import AWSCRTHTTPClient, BufferableByteStream, CRTResponseBody
1216

1317

1418
def test_deepcopy_client() -> None:
@@ -136,3 +140,77 @@ def test_end_stream() -> None:
136140
assert not stream.closed
137141
assert stream.read() == b"foo"
138142
assert stream.closed
143+
144+
145+
async def test_response_body_completed_stream() -> None:
146+
completion_future = ConcurrentFuture[int]()
147+
mock_stream = Mock(spec=HttpClientStream)
148+
mock_stream.completion_future = completion_future
149+
150+
response_body = CRTResponseBody()
151+
response_body.set_stream(mock_stream)
152+
completion_future.set_result(200)
153+
154+
assert await response_body.next() == b""
155+
156+
157+
async def test_response_body_empty_stream() -> None:
158+
completion_future = ConcurrentFuture[int]()
159+
mock_stream = Mock(spec=HttpClientStream)
160+
mock_stream.completion_future = completion_future
161+
162+
response_body = CRTResponseBody()
163+
response_body.set_stream(mock_stream)
164+
165+
read_task = asyncio.create_task(response_body.next())
166+
167+
# Sleep briefly so the read task gets priority. It should
168+
# add a chunk future and then await it.
169+
await asyncio.sleep(0.01)
170+
171+
assert len(response_body._chunk_futures) == 1 # type: ignore
172+
response_body.on_body(b"foo")
173+
assert await read_task == b"foo"
174+
175+
176+
async def test_response_body_stream_completion_clears_buffer() -> None:
177+
completion_future = ConcurrentFuture[int]()
178+
mock_stream = Mock(spec=HttpClientStream)
179+
mock_stream.completion_future = completion_future
180+
181+
response_body = CRTResponseBody()
182+
response_body.set_stream(mock_stream)
183+
184+
read_tasks = (
185+
asyncio.create_task(response_body.next()),
186+
asyncio.create_task(response_body.next()),
187+
asyncio.create_task(response_body.next()),
188+
asyncio.create_task(response_body.next()),
189+
)
190+
191+
# Sleep briefly so the read tasks gets priority. It should
192+
# add a chunk future and then await it.
193+
await asyncio.sleep(0.01)
194+
195+
assert len(response_body._chunk_futures) == 4 # type: ignore
196+
completion_future.set_result(200)
197+
await asyncio.sleep(0.01)
198+
199+
# Tasks should have been drained
200+
assert len(response_body._chunk_futures) == 0 # type: ignore
201+
202+
# Tasks should still be awaited, and should all return empty
203+
results = asyncio.gather(*read_tasks)
204+
assert results.result() == [b"", b"", b"", b""]
205+
206+
207+
async def test_response_body_non_empty_stream() -> None:
208+
completion_future = ConcurrentFuture[int]()
209+
mock_stream = Mock(spec=HttpClientStream)
210+
mock_stream.completion_future = completion_future
211+
212+
response_body = CRTResponseBody()
213+
response_body.set_stream(mock_stream)
214+
response_body.on_body(b"foo")
215+
216+
assert await response_body.next() == b"foo"

0 commit comments

Comments
 (0)