diff --git a/packages/smithy-http/src/smithy_http/aio/crt.py b/packages/smithy-http/src/smithy_http/aio/crt.py index 88531e95a..58fa3ebae 100644 --- a/packages/smithy-http/src/smithy_http/aio/crt.py +++ b/packages/smithy-http/src/smithy_http/aio/crt.py @@ -292,8 +292,14 @@ async def _marshal_request( """Create :py:class:`awscrt.http.HttpRequest` from :py:class:`smithy_http.aio.HTTPRequest`""" headers_list = [] + if "Host" not in request.fields: + request.fields.set_field( + Field(name="Host", values=[request.destination.host]) + ) + for fld in request.fields.entries.values(): - if fld.kind != FieldPosition.HEADER: + # TODO: Use literal values for "header"/"trailer". + if fld.kind.value != FieldPosition.HEADER.value: continue for val in fld.values: headers_list.append((fld.name, val)) diff --git a/packages/smithy-http/src/smithy_http/interfaces/__init__.py b/packages/smithy-http/src/smithy_http/interfaces/__init__.py index 0c3ac3cc0..2d1445199 100644 --- a/packages/smithy-http/src/smithy_http/interfaces/__init__.py +++ b/packages/smithy-http/src/smithy_http/interfaces/__init__.py @@ -98,6 +98,10 @@ def __len__(self) -> int: """Get total number of Field entries.""" ... + def __contains__(self, key: str) -> bool: + """Allows in/not in checks on Field entries.""" + ... + def get_by_type(self, kind: FieldPosition) -> list[Field]: """Helper function for retrieving specific types of fields. diff --git a/packages/smithy-http/tests/unit/aio/test_crt.py b/packages/smithy-http/tests/unit/aio/test_crt.py index 3ca5197cc..d5c695417 100644 --- a/packages/smithy-http/tests/unit/aio/test_crt.py +++ b/packages/smithy-http/tests/unit/aio/test_crt.py @@ -1,9 +1,13 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from copy import deepcopy +from io import BytesIO import pytest +from smithy_core import URI +from smithy_http import Fields +from smithy_http.aio import HTTPRequest from smithy_http.aio.crt import AWSCRTHTTPClient, BufferableByteStream @@ -12,6 +16,22 @@ def test_deepcopy_client() -> None: deepcopy(client) +async def test_client_marshal_request() -> None: + client = AWSCRTHTTPClient() + request = HTTPRequest( + method="GET", + destination=URI( + host="example.com", path="/path", query="key1=value1&key2=value2" + ), + body=BytesIO(), + fields=Fields(), + ) + crt_request = await client._marshal_request(request) # type: ignore + assert crt_request.headers.get("host") == "example.com" # type: ignore + assert crt_request.method == "GET" # type: ignore + assert crt_request.path == "/path?key1=value1&key2=value2" # type: ignore + + def test_stream_write() -> None: stream = BufferableByteStream() stream.write(b"foo")