Skip to content

Commit 8098c6a

Browse files
committed
fix type checks and remove outdated tests
1 parent 9abf092 commit 8098c6a

File tree

2 files changed

+24
-209
lines changed

2 files changed

+24
-209
lines changed

packages/smithy-http/src/smithy_http/aio/crt.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,16 @@
1414
from awscrt import http as crt_http_base
1515
from awscrt import io as crt_io
1616
from awscrt.aio.http import (
17-
AIOHttp2ClientConnection,
18-
AIOHttp2ClientStream,
1917
AIOHttpClientConnectionUnified,
20-
AIOHttpClientStream,
18+
AIOHttpClientStreamUnified,
2119
)
2220

2321
try:
2422
from awscrt import http as crt_http_base
2523
from awscrt import io as crt_io
2624
from awscrt.aio.http import (
27-
AIOHttp2ClientConnection,
28-
AIOHttp2ClientStream,
2925
AIOHttpClientConnectionUnified,
30-
AIOHttpClientStream,
26+
AIOHttpClientStreamUnified,
3127
)
3228

3329
HAS_CRT = True
@@ -70,7 +66,7 @@ def __init__(
7066
*,
7167
status: int,
7268
fields: Fields,
73-
stream: "AIOHttpClientStream | AIOHttp2ClientStream",
69+
stream: "AIOHttpClientStreamUnified",
7470
) -> None:
7571
_assert_crt()
7672
self._status = status
@@ -112,7 +108,7 @@ def __repr__(self) -> str:
112108

113109

114110
ConnectionPoolKey = tuple[str, str, int | None]
115-
ConnectionPoolDict = dict[ConnectionPoolKey, "AIOHttp2ClientConnection"]
111+
ConnectionPoolDict = dict[ConnectionPoolKey, "AIOHttpClientConnectionUnified"]
116112

117113

118114
class AWSCRTHTTPClientConfig(http_interfaces.HTTPClientConfiguration):
@@ -168,7 +164,7 @@ async def send(
168164
return await self._await_response(crt_stream)
169165

170166
async def _await_response(
171-
self, stream: "AIOHttpClientStream | AIOHttp2ClientStream"
167+
self, stream: "AIOHttpClientStreamUnified"
172168
) -> AWSCRTHTTPResponse:
173169
status_code = await stream.get_response_status_code()
174170
headers = await stream.get_response_headers()
@@ -190,15 +186,15 @@ async def _await_response(
190186

191187
async def _create_connection(
192188
self, url: core_interfaces.URI
193-
) -> "AIOHttp2ClientConnection":
189+
) -> "AIOHttpClientConnectionUnified":
194190
"""Builds and validates connection to ``url``"""
195191
connection = await self._build_new_connection(url)
196192
await self._validate_connection(connection)
197193
return connection
198194

199195
async def _get_connection(
200196
self, url: core_interfaces.URI
201-
) -> "AIOHttp2ClientConnection":
197+
) -> "AIOHttpClientConnectionUnified":
202198
# TODO: Use CRT connection pooling instead of this basic kind
203199
connection_key = (url.scheme, url.host, url.port)
204200
connection = self._connections.get(connection_key)
@@ -212,7 +208,7 @@ async def _get_connection(
212208

213209
async def _build_new_connection(
214210
self, url: core_interfaces.URI
215-
) -> "AIOHttp2ClientConnection":
211+
) -> "AIOHttpClientConnectionUnified":
216212
if url.scheme == "http":
217213
port = self._HTTP_PORT
218214
tls_connection_options = None
@@ -229,7 +225,7 @@ async def _build_new_connection(
229225
if url.port is not None:
230226
port = url.port
231227

232-
return await AIOHttp2ClientConnection.new(
228+
return await AIOHttpClientConnectionUnified.new(
233229
bootstrap=self._client_bootstrap,
234230
host_name=url.host,
235231
port=port,
@@ -297,11 +293,19 @@ async def _create_body_generator(
297293
elif isinstance(body, bytearray):
298294
# Convert bytearray to bytes
299295
yield bytes(body)
296+
elif isinstance(body, AsyncIterable):
297+
# Already async iterable, just yield from it.
298+
# Check this before AsyncByteStream since AsyncBytesReader implements both.
299+
async for chunk in body:
300+
if isinstance(chunk, bytearray):
301+
yield bytes(chunk)
302+
else:
303+
yield chunk
300304
elif iscoroutinefunction(getattr(body, "read", None)) and isinstance(
301305
body, # type: ignore[reportGeneralTypeIssues]
302306
core_aio_interfaces.AsyncByteStream, # type: ignore[reportGeneralTypeIssues]
303307
):
304-
# AsyncByteStream has async read method
308+
# AsyncByteStream has async read method but is not iterable
305309
while True:
306310
chunk = await body.read(65536) # Read in 64KB chunks
307311
if not chunk:
@@ -310,15 +314,8 @@ async def _create_body_generator(
310314
yield bytes(chunk)
311315
else:
312316
yield chunk
313-
elif isinstance(body, AsyncIterable):
314-
# Already async iterable, just yield from it
315-
async for chunk in body:
316-
if isinstance(chunk, bytearray):
317-
yield bytes(chunk)
318-
else:
319-
yield chunk
320317
else:
321-
# Assume it's a BytesReader, wrap it in AsyncBytesReader
318+
# Assume it's a sync BytesReader, wrap it in AsyncBytesReader
322319
async_reader = AsyncBytesReader(body)
323320
async for chunk in async_reader:
324321
if isinstance(chunk, bytearray):
Lines changed: 5 additions & 187 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,22 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3-
import asyncio
4-
from concurrent.futures import Future as ConcurrentFuture
53
from copy import deepcopy
64
from io import BytesIO
7-
from unittest.mock import Mock
85

9-
import pytest
10-
from awscrt.http import HttpClientStream # type: ignore
116
from smithy_core import URI
127
from smithy_http import Fields
138
from smithy_http.aio import HTTPRequest
14-
from smithy_http.aio.crt import AWSCRTHTTPClient, BufferableByteStream, CRTResponseBody
9+
from smithy_http.aio.crt import AWSCRTHTTPClient
1510

1611

1712
def test_deepcopy_client() -> None:
13+
"""Test that AWSCRTHTTPClient can be deep copied."""
1814
client = AWSCRTHTTPClient()
1915
deepcopy(client)
2016

2117

22-
async def test_client_marshal_request() -> None:
18+
def test_client_marshal_request() -> None:
19+
"""Test that HTTPRequest is correctly marshaled to CRT HttpRequest."""
2320
client = AWSCRTHTTPClient()
2421
request = HTTPRequest(
2522
method="GET",
@@ -29,187 +26,8 @@ async def test_client_marshal_request() -> None:
2926
body=BytesIO(),
3027
fields=Fields(),
3128
)
32-
crt_request, _ = await client._marshal_request(request) # type: ignore
29+
crt_request = client._marshal_request(request) # type: ignore
3330
assert crt_request.headers.get("host") == "example.com" # type: ignore
3431
assert crt_request.headers.get("accept") == "*/*" # type: ignore
3532
assert crt_request.method == "GET" # type: ignore
3633
assert crt_request.path == "/path?key1=value1&key2=value2" # type: ignore
37-
38-
39-
def test_stream_write() -> None:
40-
stream = BufferableByteStream()
41-
stream.write(b"foo")
42-
assert stream.read() == b"foo"
43-
44-
45-
def test_stream_reads_individual_chunks() -> None:
46-
stream = BufferableByteStream()
47-
stream.write(b"foo")
48-
stream.write(b"bar")
49-
assert stream.read() == b"foo"
50-
assert stream.read() == b"bar"
51-
52-
53-
def test_stream_empty_read() -> None:
54-
stream = BufferableByteStream()
55-
with pytest.raises(BlockingIOError):
56-
stream.read()
57-
58-
59-
def test_stream_partial_chunk_read() -> None:
60-
stream = BufferableByteStream()
61-
stream.write(b"foobar")
62-
assert stream.read(3) == b"foo"
63-
assert stream.read() == b"bar"
64-
65-
66-
def test_stream_write_empty_bytes() -> None:
67-
stream = BufferableByteStream()
68-
stream.write(b"")
69-
stream.write(b"foo")
70-
stream.write(b"")
71-
assert stream.read() == b"foo"
72-
73-
74-
def test_stream_write_non_bytes() -> None:
75-
stream = BufferableByteStream()
76-
with pytest.raises(ValueError):
77-
stream.write(memoryview(b"foo"))
78-
79-
80-
def test_closed_stream_write() -> None:
81-
stream = BufferableByteStream()
82-
stream.close()
83-
with pytest.raises(IOError):
84-
stream.write(b"foo")
85-
86-
87-
def test_closed_stream_read() -> None:
88-
stream = BufferableByteStream()
89-
stream.write(b"foo")
90-
stream.close()
91-
assert stream.read() == b""
92-
93-
94-
def test_done_stream_read() -> None:
95-
stream = BufferableByteStream()
96-
stream.write(b"foo")
97-
stream.end_stream()
98-
assert stream.read() == b"foo"
99-
assert stream.read() == b""
100-
101-
102-
def test_end_empty_stream() -> None:
103-
stream = BufferableByteStream()
104-
stream.end_stream()
105-
assert stream.read() == b""
106-
107-
108-
def test_stream_read1() -> None:
109-
stream = BufferableByteStream()
110-
stream.write(b"foo")
111-
stream.write(b"bar")
112-
assert stream.read1() == b"foo"
113-
assert stream.read1() == b"bar"
114-
with pytest.raises(BlockingIOError):
115-
stream.read()
116-
117-
118-
def test_stream_readinto_memoryview() -> None:
119-
buffer = memoryview(bytearray(b" "))
120-
stream = BufferableByteStream()
121-
stream.write(b"foobar")
122-
stream.readinto(buffer)
123-
assert bytes(buffer) == b"foo"
124-
125-
126-
def test_stream_readinto_bytearray() -> None:
127-
buffer = bytearray(b" ")
128-
stream = BufferableByteStream()
129-
stream.write(b"foobar")
130-
stream.readinto(buffer)
131-
assert bytes(buffer) == b"foo"
132-
133-
134-
def test_end_stream() -> None:
135-
stream = BufferableByteStream()
136-
stream.write(b"foo")
137-
stream.end_stream()
138-
139-
assert not stream.closed
140-
assert stream.read() == b"foo"
141-
assert stream.closed
142-
143-
144-
async def test_response_body_completed_stream() -> None:
145-
completion_future = ConcurrentFuture[int]()
146-
mock_stream = Mock(spec=HttpClientStream)
147-
mock_stream.completion_future = completion_future
148-
149-
response_body = CRTResponseBody()
150-
response_body.set_stream(mock_stream)
151-
completion_future.set_result(200)
152-
153-
assert await response_body.next() == b""
154-
155-
156-
async def test_response_body_empty_stream() -> None:
157-
completion_future = ConcurrentFuture[int]()
158-
mock_stream = Mock(spec=HttpClientStream)
159-
mock_stream.completion_future = completion_future
160-
161-
response_body = CRTResponseBody()
162-
response_body.set_stream(mock_stream)
163-
164-
read_task = asyncio.create_task(response_body.next())
165-
166-
# Sleep briefly so the read task gets priority. It should
167-
# add a chunk future and then await it.
168-
await asyncio.sleep(0.01)
169-
170-
assert len(response_body._chunk_futures) == 1 # type: ignore
171-
response_body.on_body(b"foo")
172-
assert await read_task == b"foo"
173-
174-
175-
async def test_response_body_stream_completion_clears_buffer() -> None:
176-
completion_future = ConcurrentFuture[int]()
177-
mock_stream = Mock(spec=HttpClientStream)
178-
mock_stream.completion_future = completion_future
179-
180-
response_body = CRTResponseBody()
181-
response_body.set_stream(mock_stream)
182-
183-
read_tasks = (
184-
asyncio.create_task(response_body.next()),
185-
asyncio.create_task(response_body.next()),
186-
asyncio.create_task(response_body.next()),
187-
asyncio.create_task(response_body.next()),
188-
)
189-
190-
# Sleep briefly so the read tasks gets priority. It should
191-
# add a chunk future and then await it.
192-
await asyncio.sleep(0.01)
193-
194-
assert len(response_body._chunk_futures) == 4 # type: ignore
195-
completion_future.set_result(200)
196-
await asyncio.sleep(0.01)
197-
198-
# Tasks should have been drained
199-
assert len(response_body._chunk_futures) == 0 # type: ignore
200-
201-
# Tasks should still be awaited, and should all return empty
202-
results = asyncio.gather(*read_tasks)
203-
assert results.result() == [b"", b"", b"", b""]
204-
205-
206-
async def test_response_body_non_empty_stream() -> None:
207-
completion_future = ConcurrentFuture[int]()
208-
mock_stream = Mock(spec=HttpClientStream)
209-
mock_stream.completion_future = completion_future
210-
211-
response_body = CRTResponseBody()
212-
response_body.set_stream(mock_stream)
213-
response_body.on_body(b"foo")
214-
215-
assert await response_body.next() == b"foo"

0 commit comments

Comments
 (0)