11# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22# SPDX-License-Identifier: Apache-2.0
3+ # pyright: reportPrivateUsage=false
4+ import asyncio
5+ from collections .abc import AsyncIterator
36from copy import deepcopy
47from io import BytesIO
8+ from unittest .mock import AsyncMock , Mock , patch
59
10+ import pytest
11+ from awscrt import http as crt_http # type: ignore
612from smithy_core import URI
7- from smithy_http import Fields
13+ from smithy_core .aio .types import AsyncBytesReader
14+ from smithy_http import Field , Fields
815from smithy_http .aio import HTTPRequest
9- from smithy_http .aio .crt import AWSCRTHTTPClient
16+ from smithy_http .aio .crt import (
17+ AWSCRTHTTPClient ,
18+ AWSCRTHTTPClientConfig ,
19+ AWSCRTHTTPResponse ,
20+ )
21+ from smithy_http .exceptions import SmithyHTTPError
1022
1123
1224def test_deepcopy_client () -> None :
@@ -26,8 +38,336 @@ def test_client_marshal_request() -> None:
2638 body = BytesIO (),
2739 fields = Fields (),
2840 )
29- crt_request = client ._marshal_request (request ) # type: ignore
30- assert crt_request .headers .get ("host" ) == "example.com" # type: ignore
31- assert crt_request .headers .get ("accept" ) == "*/*" # type: ignore
32- assert crt_request .method == "GET" # type: ignore
33- assert crt_request .path == "/path?key1=value1&key2=value2" # type: ignore
41+ crt_request = client ._marshal_request (request )
42+ assert crt_request .headers .get ("host" ) == "example.com"
43+ assert crt_request .headers .get ("accept" ) == "*/*"
44+ assert crt_request .method == "GET"
45+ assert crt_request .path == "/path?key1=value1&key2=value2"
46+
47+
48+ async def test_body_generator_bytes () -> None :
49+ """Test body generator with bytes input."""
50+ client = AWSCRTHTTPClient ()
51+ body = b"Hello, World!"
52+
53+ chunks : list [bytes ] = []
54+ async for chunk in client ._create_body_generator (body ):
55+ chunks .append (chunk )
56+
57+ assert chunks == [b"Hello, World!" ]
58+
59+
60+ async def test_body_generator_bytearray () -> None :
61+ """Test body generator with bytearray input (should convert to bytes)."""
62+ client = AWSCRTHTTPClient ()
63+ body = bytearray (b"mutable data" )
64+
65+ chunks : list [bytes ] = []
66+ async for chunk in client ._create_body_generator (body ):
67+ chunks .append (chunk )
68+
69+ assert chunks == [b"mutable data" ]
70+ assert all (isinstance (chunk , bytes ) for chunk in chunks )
71+
72+
73+ async def test_body_generator_bytesio () -> None :
74+ """Test body generator with BytesIO (sync reader)."""
75+ client = AWSCRTHTTPClient ()
76+ body = BytesIO (b"data from BytesIO" )
77+
78+ chunks : list [bytes ] = []
79+ async for chunk in client ._create_body_generator (body ):
80+ chunks .append (chunk )
81+
82+ result = b"" .join (chunks )
83+ assert result == b"data from BytesIO"
84+
85+
86+ async def test_body_generator_async_bytes_reader () -> None :
87+ """Test body generator with AsyncBytesReader."""
88+ client = AWSCRTHTTPClient ()
89+ body = AsyncBytesReader (b"async reader data" )
90+
91+ chunks : list [bytes ] = []
92+ async for chunk in client ._create_body_generator (body ):
93+ chunks .append (chunk )
94+
95+ result = b"" .join (chunks )
96+ assert result == b"async reader data"
97+
98+
99+ async def test_body_generator_async_iterable () -> None :
100+ """Test body generator with custom AsyncIterable."""
101+
102+ async def custom_generator () -> AsyncIterator [bytes ]:
103+ yield b"chunk1"
104+ yield b"chunk2"
105+ yield b"chunk3"
106+
107+ client = AWSCRTHTTPClient ()
108+ body = custom_generator ()
109+
110+ chunks : list [bytes ] = []
111+ async for chunk in client ._create_body_generator (body ):
112+ chunks .append (chunk )
113+
114+ assert chunks == [b"chunk1" , b"chunk2" , b"chunk3" ]
115+
116+
117+ async def test_body_generator_async_iterable_with_bytearray () -> None :
118+ """Test that AsyncIterable yielding bytearray converts to bytes."""
119+
120+ async def generator_with_bytearray () -> AsyncIterator [bytes | bytearray ]:
121+ yield b"bytes chunk"
122+ yield bytearray (b"bytearray chunk" )
123+ yield b"more bytes"
124+
125+ client = AWSCRTHTTPClient ()
126+ body = generator_with_bytearray ()
127+
128+ chunks : list [bytes ] = []
129+ async for chunk in client ._create_body_generator (body ): # type: ignore
130+ chunks .append (chunk )
131+
132+ assert chunks == [b"bytes chunk" , b"bytearray chunk" , b"more bytes" ]
133+ assert all (isinstance (chunk , bytes ) for chunk in chunks )
134+
135+
136+ async def test_body_generator_async_byte_stream () -> None :
137+ """Test body generator with AsyncByteStream (object with async read)."""
138+
139+ class CustomAsyncStream :
140+ def __init__ (self , data : bytes ):
141+ self ._data = BytesIO (data )
142+
143+ async def read (self , size : int = - 1 ) -> bytes :
144+ # Simulate async read
145+ await asyncio .sleep (0 )
146+ return self ._data .read (size )
147+
148+ client = AWSCRTHTTPClient ()
149+ body = CustomAsyncStream (b"x" * 100000 ) # 100KB of data
150+
151+ chunks : list [bytes ] = []
152+ async for chunk in client ._create_body_generator (body ):
153+ chunks .append (chunk )
154+
155+ # Should read in 64KB chunks
156+ result = b"" .join (chunks )
157+ assert len (result ) == 100000
158+ assert result == b"x" * 100000
159+
160+
161+ async def test_body_generator_empty_bytes () -> None :
162+ """Test body generator with empty bytes."""
163+ client = AWSCRTHTTPClient ()
164+ body = b""
165+
166+ chunks : list [bytes ] = []
167+ async for chunk in client ._create_body_generator (body ):
168+ chunks .append (chunk )
169+
170+ assert chunks == [b"" ]
171+
172+
173+ async def test_build_connection_http () -> None :
174+ """Test building HTTP connection."""
175+ client = AWSCRTHTTPClient ()
176+ url = URI (scheme = "http" , host = "example.com" , port = 8080 )
177+
178+ with patch ("smithy_http.aio.crt.AIOHttpClientConnectionUnified.new" ) as mock_new :
179+ mock_connection = AsyncMock ()
180+ mock_connection .version = crt_http .HttpVersion .Http1_1
181+ mock_connection .is_open = Mock (return_value = True )
182+ mock_new .return_value = mock_connection
183+
184+ connection = await client ._build_new_connection (url )
185+
186+ assert connection is mock_connection
187+ mock_new .assert_called_once ()
188+ call_kwargs = mock_new .call_args [1 ]
189+ assert call_kwargs ["host_name" ] == "example.com"
190+ assert call_kwargs ["port" ] == 8080
191+ assert call_kwargs ["tls_connection_options" ] is None
192+
193+
194+ async def test_build_connection_https () -> None :
195+ """Test building HTTPS connection with TLS."""
196+ client = AWSCRTHTTPClient ()
197+ url = URI (scheme = "https" , host = "secure.example.com" )
198+
199+ with patch ("smithy_http.aio.crt.AIOHttpClientConnectionUnified.new" ) as mock_new :
200+ mock_connection = AsyncMock ()
201+ mock_connection .version = crt_http .HttpVersion .Http2
202+ mock_connection .is_open = Mock (return_value = True )
203+ mock_new .return_value = mock_connection
204+
205+ connection = await client ._build_new_connection (url )
206+
207+ assert connection is mock_connection
208+ mock_new .assert_called_once ()
209+ call_kwargs = mock_new .call_args [1 ]
210+ assert call_kwargs ["host_name" ] == "secure.example.com"
211+ assert call_kwargs ["port" ] == 443
212+ assert call_kwargs ["tls_connection_options" ] is not None
213+
214+
215+ async def test_build_connection_unsupported_scheme () -> None :
216+ """Test that unsupported URL schemes raise error."""
217+ client = AWSCRTHTTPClient ()
218+ url = URI (scheme = "ftp" , host = "example.com" )
219+
220+ with pytest .raises (SmithyHTTPError , match = "does not support URL scheme ftp" ):
221+ await client ._build_new_connection (url )
222+
223+
224+ async def test_validate_connection_http2_required () -> None :
225+ """Test connection validation when force_http_2 is enabled."""
226+ config = AWSCRTHTTPClientConfig (force_http_2 = True )
227+ client = AWSCRTHTTPClient (client_config = config )
228+
229+ # Mock HTTP/1.1 connection
230+ mock_connection = AsyncMock ()
231+ mock_connection .version = crt_http .HttpVersion .Http1_1
232+ mock_connection .close = AsyncMock ()
233+
234+ with pytest .raises (SmithyHTTPError , match = "HTTP/2 could not be negotiated" ):
235+ await client ._validate_connection (mock_connection )
236+
237+ mock_connection .close .assert_called_once ()
238+
239+
240+ async def test_validate_connection_http2_success () -> None :
241+ """Test connection validation succeeds with HTTP/2."""
242+ config = AWSCRTHTTPClientConfig (force_http_2 = True )
243+ client = AWSCRTHTTPClient (client_config = config )
244+
245+ # Mock HTTP/2 connection
246+ mock_connection = AsyncMock ()
247+ mock_connection .version = crt_http .HttpVersion .Http2
248+
249+ # Should not raise
250+ await client ._validate_connection (mock_connection )
251+
252+
253+ async def test_connection_pooling () -> None :
254+ """Test that connections are pooled and reused."""
255+ client = AWSCRTHTTPClient ()
256+ url = URI (scheme = "https" , host = "example.com" )
257+
258+ # Mock connection
259+ mock_connection = AsyncMock ()
260+ mock_connection .version = crt_http .HttpVersion .Http2
261+ # is_open() should be a regular method, not async
262+ mock_connection .is_open = Mock (return_value = True )
263+
264+ with patch ("smithy_http.aio.crt.AIOHttpClientConnectionUnified.new" ) as mock_new :
265+ mock_new .return_value = mock_connection
266+
267+ # First call should create new connection
268+ conn1 = await client ._get_connection (url )
269+ assert mock_new .call_count == 1
270+
271+ # Second call should reuse connection
272+ conn2 = await client ._get_connection (url )
273+ assert mock_new .call_count == 1 # Not called again
274+ assert conn1 is conn2
275+
276+
277+ async def test_connection_pooling_different_hosts () -> None :
278+ """Test that different hosts get different connections."""
279+ client = AWSCRTHTTPClient ()
280+ url1 = URI (scheme = "https" , host = "example1.com" )
281+ url2 = URI (scheme = "https" , host = "example2.com" )
282+
283+ # Create two distinct mock connections
284+ mock_conn1 = AsyncMock ()
285+ mock_conn1 .version = crt_http .HttpVersion .Http2
286+ mock_conn1 .is_open = Mock (return_value = True )
287+
288+ mock_conn2 = AsyncMock ()
289+ mock_conn2 .version = crt_http .HttpVersion .Http2
290+ mock_conn2 .is_open = Mock (return_value = True )
291+
292+ with patch ("smithy_http.aio.crt.AIOHttpClientConnectionUnified.new" ) as mock_new :
293+ mock_new .side_effect = [mock_conn1 , mock_conn2 ]
294+
295+ conn1 = await client ._get_connection (url1 )
296+ conn2 = await client ._get_connection (url2 )
297+
298+ assert mock_new .call_count == 2
299+ assert conn1 is mock_conn1
300+ assert conn2 is mock_conn2
301+ assert conn1 is not conn2
302+
303+
304+ async def test_connection_pooling_closed_connection () -> None :
305+ """Test that closed connections are replaced."""
306+ client = AWSCRTHTTPClient ()
307+ url = URI (scheme = "https" , host = "example.com" )
308+
309+ mock_connection1 = AsyncMock ()
310+ mock_connection1 .version = crt_http .HttpVersion .Http2
311+ mock_connection1 .is_open = Mock (return_value = False ) # Closed
312+
313+ mock_connection2 = AsyncMock ()
314+ mock_connection2 .version = crt_http .HttpVersion .Http2
315+ mock_connection2 .is_open = Mock (return_value = True )
316+
317+ with patch ("smithy_http.aio.crt.AIOHttpClientConnectionUnified.new" ) as mock_new :
318+ mock_new .side_effect = [mock_connection1 , mock_connection2 ]
319+
320+ # First call
321+ conn1 = await client ._get_connection (url )
322+ assert conn1 is mock_connection1
323+
324+ # Connection is now closed, should create new one
325+ conn2 = await client ._get_connection (url )
326+ assert conn2 is mock_connection2
327+ assert mock_new .call_count == 2
328+
329+
330+ async def test_response_chunks () -> None :
331+ """Test reading response body chunks."""
332+ mock_stream = AsyncMock ()
333+ mock_stream .get_next_response_chunk .side_effect = [
334+ b"chunk1" ,
335+ b"chunk2" ,
336+ b"chunk3" ,
337+ b"" , # End of stream
338+ ]
339+
340+ response = AWSCRTHTTPResponse (status = 200 , fields = Fields (), stream = mock_stream )
341+
342+ chunks : list [bytes ] = []
343+ async for chunk in response .chunks ():
344+ chunks .append (chunk )
345+
346+ assert chunks == [b"chunk1" , b"chunk2" , b"chunk3" ]
347+
348+
349+ async def test_response_body_property () -> None :
350+ """Test that body property returns chunks."""
351+ mock_stream = AsyncMock ()
352+ mock_stream .get_next_response_chunk .side_effect = [b"data" , b"" ]
353+
354+ response = AWSCRTHTTPResponse (status = 200 , fields = Fields (), stream = mock_stream )
355+
356+ chunks : list [bytes ] = []
357+ async for chunk in response .body :
358+ chunks .append (chunk )
359+
360+ assert chunks == [b"data" ]
361+
362+
363+ def test_response_properties () -> None :
364+ """Test response property accessors."""
365+ fields = Fields ()
366+ fields .set_field (Field (name = "content-type" , values = ["application/json" ]))
367+
368+ mock_stream = Mock ()
369+ response = AWSCRTHTTPResponse (status = 404 , fields = fields , stream = mock_stream )
370+
371+ assert response .status == 404
372+ assert response .fields == fields
373+ assert response .reason is None
0 commit comments