Skip to content

Commit b68df1b

Browse files
committed
Add tests
1 parent 1d4d782 commit b68df1b

File tree

1 file changed

+347
-7
lines changed

1 file changed

+347
-7
lines changed

packages/smithy-http/tests/unit/aio/test_crt.py

Lines changed: 347 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3+
# pyright: reportPrivateUsage=false
4+
import asyncio
5+
from collections.abc import AsyncIterator
36
from copy import deepcopy
47
from io import BytesIO
8+
from unittest.mock import AsyncMock, Mock, patch
59

10+
import pytest
11+
from awscrt import http as crt_http # type: ignore
612
from smithy_core import URI
7-
from smithy_http import Fields
13+
from smithy_core.aio.types import AsyncBytesReader
14+
from smithy_http import Field, Fields
815
from smithy_http.aio import HTTPRequest
9-
from smithy_http.aio.crt import AWSCRTHTTPClient
16+
from smithy_http.aio.crt import (
17+
AWSCRTHTTPClient,
18+
AWSCRTHTTPClientConfig,
19+
AWSCRTHTTPResponse,
20+
)
21+
from smithy_http.exceptions import SmithyHTTPError
1022

1123

1224
def test_deepcopy_client() -> None:
@@ -26,8 +38,336 @@ def test_client_marshal_request() -> None:
2638
body=BytesIO(),
2739
fields=Fields(),
2840
)
29-
crt_request = client._marshal_request(request) # type: ignore
30-
assert crt_request.headers.get("host") == "example.com" # type: ignore
31-
assert crt_request.headers.get("accept") == "*/*" # type: ignore
32-
assert crt_request.method == "GET" # type: ignore
33-
assert crt_request.path == "/path?key1=value1&key2=value2" # type: ignore
41+
crt_request = client._marshal_request(request)
42+
assert crt_request.headers.get("host") == "example.com"
43+
assert crt_request.headers.get("accept") == "*/*"
44+
assert crt_request.method == "GET"
45+
assert crt_request.path == "/path?key1=value1&key2=value2"
46+
47+
48+
async def test_body_generator_bytes() -> None:
49+
"""Test body generator with bytes input."""
50+
client = AWSCRTHTTPClient()
51+
body = b"Hello, World!"
52+
53+
chunks: list[bytes] = []
54+
async for chunk in client._create_body_generator(body):
55+
chunks.append(chunk)
56+
57+
assert chunks == [b"Hello, World!"]
58+
59+
60+
async def test_body_generator_bytearray() -> None:
61+
"""Test body generator with bytearray input (should convert to bytes)."""
62+
client = AWSCRTHTTPClient()
63+
body = bytearray(b"mutable data")
64+
65+
chunks: list[bytes] = []
66+
async for chunk in client._create_body_generator(body):
67+
chunks.append(chunk)
68+
69+
assert chunks == [b"mutable data"]
70+
assert all(isinstance(chunk, bytes) for chunk in chunks)
71+
72+
73+
async def test_body_generator_bytesio() -> None:
74+
"""Test body generator with BytesIO (sync reader)."""
75+
client = AWSCRTHTTPClient()
76+
body = BytesIO(b"data from BytesIO")
77+
78+
chunks: list[bytes] = []
79+
async for chunk in client._create_body_generator(body):
80+
chunks.append(chunk)
81+
82+
result = b"".join(chunks)
83+
assert result == b"data from BytesIO"
84+
85+
86+
async def test_body_generator_async_bytes_reader() -> None:
87+
"""Test body generator with AsyncBytesReader."""
88+
client = AWSCRTHTTPClient()
89+
body = AsyncBytesReader(b"async reader data")
90+
91+
chunks: list[bytes] = []
92+
async for chunk in client._create_body_generator(body):
93+
chunks.append(chunk)
94+
95+
result = b"".join(chunks)
96+
assert result == b"async reader data"
97+
98+
99+
async def test_body_generator_async_iterable() -> None:
100+
"""Test body generator with custom AsyncIterable."""
101+
102+
async def custom_generator() -> AsyncIterator[bytes]:
103+
yield b"chunk1"
104+
yield b"chunk2"
105+
yield b"chunk3"
106+
107+
client = AWSCRTHTTPClient()
108+
body = custom_generator()
109+
110+
chunks: list[bytes] = []
111+
async for chunk in client._create_body_generator(body):
112+
chunks.append(chunk)
113+
114+
assert chunks == [b"chunk1", b"chunk2", b"chunk3"]
115+
116+
117+
async def test_body_generator_async_iterable_with_bytearray() -> None:
118+
"""Test that AsyncIterable yielding bytearray converts to bytes."""
119+
120+
async def generator_with_bytearray() -> AsyncIterator[bytes | bytearray]:
121+
yield b"bytes chunk"
122+
yield bytearray(b"bytearray chunk")
123+
yield b"more bytes"
124+
125+
client = AWSCRTHTTPClient()
126+
body = generator_with_bytearray()
127+
128+
chunks: list[bytes] = []
129+
async for chunk in client._create_body_generator(body): # type: ignore
130+
chunks.append(chunk)
131+
132+
assert chunks == [b"bytes chunk", b"bytearray chunk", b"more bytes"]
133+
assert all(isinstance(chunk, bytes) for chunk in chunks)
134+
135+
136+
async def test_body_generator_async_byte_stream() -> None:
137+
"""Test body generator with AsyncByteStream (object with async read)."""
138+
139+
class CustomAsyncStream:
140+
def __init__(self, data: bytes):
141+
self._data = BytesIO(data)
142+
143+
async def read(self, size: int = -1) -> bytes:
144+
# Simulate async read
145+
await asyncio.sleep(0)
146+
return self._data.read(size)
147+
148+
client = AWSCRTHTTPClient()
149+
body = CustomAsyncStream(b"x" * 100000) # 100KB of data
150+
151+
chunks: list[bytes] = []
152+
async for chunk in client._create_body_generator(body):
153+
chunks.append(chunk)
154+
155+
# Should read in 64KB chunks
156+
result = b"".join(chunks)
157+
assert len(result) == 100000
158+
assert result == b"x" * 100000
159+
160+
161+
async def test_body_generator_empty_bytes() -> None:
162+
"""Test body generator with empty bytes."""
163+
client = AWSCRTHTTPClient()
164+
body = b""
165+
166+
chunks: list[bytes] = []
167+
async for chunk in client._create_body_generator(body):
168+
chunks.append(chunk)
169+
170+
assert chunks == [b""]
171+
172+
173+
async def test_build_connection_http() -> None:
174+
"""Test building HTTP connection."""
175+
client = AWSCRTHTTPClient()
176+
url = URI(scheme="http", host="example.com", port=8080)
177+
178+
with patch("smithy_http.aio.crt.AIOHttpClientConnectionUnified.new") as mock_new:
179+
mock_connection = AsyncMock()
180+
mock_connection.version = crt_http.HttpVersion.Http1_1
181+
mock_connection.is_open = Mock(return_value=True)
182+
mock_new.return_value = mock_connection
183+
184+
connection = await client._build_new_connection(url)
185+
186+
assert connection is mock_connection
187+
mock_new.assert_called_once()
188+
call_kwargs = mock_new.call_args[1]
189+
assert call_kwargs["host_name"] == "example.com"
190+
assert call_kwargs["port"] == 8080
191+
assert call_kwargs["tls_connection_options"] is None
192+
193+
194+
async def test_build_connection_https() -> None:
195+
"""Test building HTTPS connection with TLS."""
196+
client = AWSCRTHTTPClient()
197+
url = URI(scheme="https", host="secure.example.com")
198+
199+
with patch("smithy_http.aio.crt.AIOHttpClientConnectionUnified.new") as mock_new:
200+
mock_connection = AsyncMock()
201+
mock_connection.version = crt_http.HttpVersion.Http2
202+
mock_connection.is_open = Mock(return_value=True)
203+
mock_new.return_value = mock_connection
204+
205+
connection = await client._build_new_connection(url)
206+
207+
assert connection is mock_connection
208+
mock_new.assert_called_once()
209+
call_kwargs = mock_new.call_args[1]
210+
assert call_kwargs["host_name"] == "secure.example.com"
211+
assert call_kwargs["port"] == 443
212+
assert call_kwargs["tls_connection_options"] is not None
213+
214+
215+
async def test_build_connection_unsupported_scheme() -> None:
216+
"""Test that unsupported URL schemes raise error."""
217+
client = AWSCRTHTTPClient()
218+
url = URI(scheme="ftp", host="example.com")
219+
220+
with pytest.raises(SmithyHTTPError, match="does not support URL scheme ftp"):
221+
await client._build_new_connection(url)
222+
223+
224+
async def test_validate_connection_http2_required() -> None:
225+
"""Test connection validation when force_http_2 is enabled."""
226+
config = AWSCRTHTTPClientConfig(force_http_2=True)
227+
client = AWSCRTHTTPClient(client_config=config)
228+
229+
# Mock HTTP/1.1 connection
230+
mock_connection = AsyncMock()
231+
mock_connection.version = crt_http.HttpVersion.Http1_1
232+
mock_connection.close = AsyncMock()
233+
234+
with pytest.raises(SmithyHTTPError, match="HTTP/2 could not be negotiated"):
235+
await client._validate_connection(mock_connection)
236+
237+
mock_connection.close.assert_called_once()
238+
239+
240+
async def test_validate_connection_http2_success() -> None:
241+
"""Test connection validation succeeds with HTTP/2."""
242+
config = AWSCRTHTTPClientConfig(force_http_2=True)
243+
client = AWSCRTHTTPClient(client_config=config)
244+
245+
# Mock HTTP/2 connection
246+
mock_connection = AsyncMock()
247+
mock_connection.version = crt_http.HttpVersion.Http2
248+
249+
# Should not raise
250+
await client._validate_connection(mock_connection)
251+
252+
253+
async def test_connection_pooling() -> None:
254+
"""Test that connections are pooled and reused."""
255+
client = AWSCRTHTTPClient()
256+
url = URI(scheme="https", host="example.com")
257+
258+
# Mock connection
259+
mock_connection = AsyncMock()
260+
mock_connection.version = crt_http.HttpVersion.Http2
261+
# is_open() should be a regular method, not async
262+
mock_connection.is_open = Mock(return_value=True)
263+
264+
with patch("smithy_http.aio.crt.AIOHttpClientConnectionUnified.new") as mock_new:
265+
mock_new.return_value = mock_connection
266+
267+
# First call should create new connection
268+
conn1 = await client._get_connection(url)
269+
assert mock_new.call_count == 1
270+
271+
# Second call should reuse connection
272+
conn2 = await client._get_connection(url)
273+
assert mock_new.call_count == 1 # Not called again
274+
assert conn1 is conn2
275+
276+
277+
async def test_connection_pooling_different_hosts() -> None:
278+
"""Test that different hosts get different connections."""
279+
client = AWSCRTHTTPClient()
280+
url1 = URI(scheme="https", host="example1.com")
281+
url2 = URI(scheme="https", host="example2.com")
282+
283+
# Create two distinct mock connections
284+
mock_conn1 = AsyncMock()
285+
mock_conn1.version = crt_http.HttpVersion.Http2
286+
mock_conn1.is_open = Mock(return_value=True)
287+
288+
mock_conn2 = AsyncMock()
289+
mock_conn2.version = crt_http.HttpVersion.Http2
290+
mock_conn2.is_open = Mock(return_value=True)
291+
292+
with patch("smithy_http.aio.crt.AIOHttpClientConnectionUnified.new") as mock_new:
293+
mock_new.side_effect = [mock_conn1, mock_conn2]
294+
295+
conn1 = await client._get_connection(url1)
296+
conn2 = await client._get_connection(url2)
297+
298+
assert mock_new.call_count == 2
299+
assert conn1 is mock_conn1
300+
assert conn2 is mock_conn2
301+
assert conn1 is not conn2
302+
303+
304+
async def test_connection_pooling_closed_connection() -> None:
305+
"""Test that closed connections are replaced."""
306+
client = AWSCRTHTTPClient()
307+
url = URI(scheme="https", host="example.com")
308+
309+
mock_connection1 = AsyncMock()
310+
mock_connection1.version = crt_http.HttpVersion.Http2
311+
mock_connection1.is_open = Mock(return_value=False) # Closed
312+
313+
mock_connection2 = AsyncMock()
314+
mock_connection2.version = crt_http.HttpVersion.Http2
315+
mock_connection2.is_open = Mock(return_value=True)
316+
317+
with patch("smithy_http.aio.crt.AIOHttpClientConnectionUnified.new") as mock_new:
318+
mock_new.side_effect = [mock_connection1, mock_connection2]
319+
320+
# First call
321+
conn1 = await client._get_connection(url)
322+
assert conn1 is mock_connection1
323+
324+
# Connection is now closed, should create new one
325+
conn2 = await client._get_connection(url)
326+
assert conn2 is mock_connection2
327+
assert mock_new.call_count == 2
328+
329+
330+
async def test_response_chunks() -> None:
331+
"""Test reading response body chunks."""
332+
mock_stream = AsyncMock()
333+
mock_stream.get_next_response_chunk.side_effect = [
334+
b"chunk1",
335+
b"chunk2",
336+
b"chunk3",
337+
b"", # End of stream
338+
]
339+
340+
response = AWSCRTHTTPResponse(status=200, fields=Fields(), stream=mock_stream)
341+
342+
chunks: list[bytes] = []
343+
async for chunk in response.chunks():
344+
chunks.append(chunk)
345+
346+
assert chunks == [b"chunk1", b"chunk2", b"chunk3"]
347+
348+
349+
async def test_response_body_property() -> None:
350+
"""Test that body property returns chunks."""
351+
mock_stream = AsyncMock()
352+
mock_stream.get_next_response_chunk.side_effect = [b"data", b""]
353+
354+
response = AWSCRTHTTPResponse(status=200, fields=Fields(), stream=mock_stream)
355+
356+
chunks: list[bytes] = []
357+
async for chunk in response.body:
358+
chunks.append(chunk)
359+
360+
assert chunks == [b"data"]
361+
362+
363+
def test_response_properties() -> None:
364+
"""Test response property accessors."""
365+
fields = Fields()
366+
fields.set_field(Field(name="content-type", values=["application/json"]))
367+
368+
mock_stream = Mock()
369+
response = AWSCRTHTTPResponse(status=404, fields=fields, stream=mock_stream)
370+
371+
assert response.status == 404
372+
assert response.fields == fields
373+
assert response.reason is None

0 commit comments

Comments
 (0)