Skip to content

Commit 4b6ba2c

Browse files
committed
comments
1 parent 17eb3b2 commit 4b6ba2c

File tree

5 files changed

+47
-12
lines changed

5 files changed

+47
-12
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0

packages/smithy-aws-core/src/smithy_aws_core/protocols/restjson.py renamed to packages/smithy-aws-core/src/smithy_aws_core/aio/protocols.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Literal
22

33
from smithy_aws_core.traits import RestJson1Trait
4-
from smithy_core.protocols import HttpBindingClientProtocol
4+
from smithy_http.aio.protocols import HttpBindingClientProtocol
55
from smithy_core.codecs import Codec
66
from smithy_core.shapes import ShapeID
77
from smithy_json import JSONCodec

packages/smithy-core/src/smithy_core/protocols.py renamed to packages/smithy-http/src/smithy_http/aio/protocols.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,7 @@ def serialize_request[
7171
endpoint_trait=operation.schema.get_trait(EndpointTrait),
7272
)
7373

74-
input.serialize(
75-
serializer=serializer
76-
) # TODO: ensure serializer adds content-type
74+
input.serialize(serializer=serializer)
7775
request = serializer.result
7876

7977
if request is None:
@@ -95,7 +93,7 @@ async def deserialize_response[
9593
error_registry: TypeRegistry,
9694
context: TypedProperties,
9795
) -> OperationOutput:
98-
if not (200 <= response.status <= 299): # TODO: extract to utility
96+
if not (200 <= response.status <= 299):
9997
# TODO: implement error serde from type registry
10098
raise NotImplementedError
10199

packages/smithy-http/src/smithy_http/serializers.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
TimestampFormatTrait,
3030
EndpointTrait,
3131
HTTPErrorTrait,
32+
MediaTypeTrait,
33+
StreamingTrait,
3234
)
3335
from smithy_core.shapes import ShapeType
3436
from smithy_core.utils import serialize_float
@@ -83,22 +85,33 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
8385
if self._endpoint_trait is not None:
8486
host_prefix = self._endpoint_trait.host_prefix
8587

88+
content_type = self._payload_codec.media_type
89+
8690
if (payload_member := self._get_payload_member(schema)) is not None:
8791
if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING):
92+
content_type = (
93+
"application/octet-stream"
94+
if payload_member.shape_type is ShapeType.BLOB
95+
else "text/plain"
96+
)
8897
payload_serializer = RawPayloadSerializer()
8998
binding_serializer = HTTPRequestBindingSerializer(
9099
payload_serializer, self._http_trait.path, host_prefix
91100
)
92101
yield binding_serializer
93102
payload = payload_serializer.payload
94103
else:
104+
if (media_type := payload_member.get_trait(MediaTypeTrait)) is not None:
105+
content_type = media_type.value
95106
payload = BytesIO()
96107
payload_serializer = self._payload_codec.create_serializer(payload)
97108
binding_serializer = HTTPRequestBindingSerializer(
98109
payload_serializer, self._http_trait.path, host_prefix
99110
)
100111
yield binding_serializer
101112
else:
113+
if self._get_eventstreaming_member(schema) is not None:
114+
content_type = "application/vnd.amazon.eventstream"
102115
payload = BytesIO()
103116
payload_serializer = self._payload_codec.create_serializer(payload)
104117
with payload_serializer.begin_struct(schema) as body_serializer:
@@ -112,9 +125,9 @@ def begin_struct(self, schema: Schema) -> Iterator[ShapeSerializer]:
112125
) is not None and not iscoroutinefunction(seek):
113126
seek(0)
114127

115-
# TODO: conditional on empty-ness and a param of the protocol?
128+
# TODO: conditional on empty-ness and based on the protocol
116129
headers = binding_serializer.header_serializer.headers
117-
headers.append(("content-type", self._payload_codec.media_type))
130+
headers.append(("content-type", content_type))
118131

119132
self.result = _HTTPRequest(
120133
method=self._http_trait.method,
@@ -136,6 +149,15 @@ def _get_payload_member(self, schema: Schema) -> Schema | None:
136149
return member
137150
return None
138151

152+
def _get_eventstreaming_member(self, schema: Schema) -> Schema | None:
153+
for member in schema.members.values():
154+
if (
155+
member.get_trait(StreamingTrait) is not None
156+
and member.shape_type is ShapeType.UNION
157+
):
158+
return member
159+
return None
160+
139161

140162
class HTTPRequestBindingSerializer(InterceptingSerializer):
141163
"""Delegates HTTP request bindings to binding-location-specific serializers."""

packages/smithy-http/tests/unit/test_serializers.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,11 +1572,16 @@ def payload_cases() -> list[HTTPMessageTestCase]:
15721572
),
15731573
HTTPMessageTestCase(
15741574
HTTPStringPayload(payload="foo"),
1575-
HTTPMessage(body=b"foo"),
1575+
HTTPMessage(
1576+
fields=tuples_to_fields([("content-type", "text/plain")]), body=b"foo"
1577+
),
15761578
),
15771579
HTTPMessageTestCase(
15781580
HTTPBlobPayload(payload=b"\xde\xad\xbe\xef"),
1579-
HTTPMessage(body=b"\xde\xad\xbe\xef"),
1581+
HTTPMessage(
1582+
fields=tuples_to_fields([("content-type", "application/octet-stream")]),
1583+
body=b"\xde\xad\xbe\xef",
1584+
),
15801585
),
15811586
HTTPMessageTestCase(
15821587
HTTPStructuredPayload(payload=HTTPStringPayload(payload="foo")),
@@ -1589,7 +1594,10 @@ def async_streaming_payload_cases() -> list[HTTPMessageTestCase]:
15891594
return [
15901595
HTTPMessageTestCase(
15911596
HTTPStreamingPayload(payload=AsyncBytesReader(b"\xde\xad\xbe\xef")),
1592-
HTTPMessage(body=AsyncBytesReader(b"\xde\xad\xbe\xef")),
1597+
HTTPMessage(
1598+
fields=tuples_to_fields([("content-type", "application/octet-stream")]),
1599+
body=AsyncBytesReader(b"\xde\xad\xbe\xef"),
1600+
),
15931601
),
15941602
]
15951603

@@ -1625,8 +1633,10 @@ async def test_serialize_http_request(case: HTTPMessageTestCase) -> None:
16251633
actual_query = actual.destination.query or ""
16261634
expected_query = case.request.destination.query or ""
16271635
assert actual_query == expected_query
1628-
# set the content-type field here, otherwise cases would have to duplicate it everywhere
1629-
expected.fields.set_field(CONTENT_TYPE_FIELD)
1636+
# set the content-type field here, otherwise cases would have to duplicate it everywhere,
1637+
# but if the field is already set in the case, don't override it
1638+
if expected.fields.get(CONTENT_TYPE_FIELD.name) is None:
1639+
expected.fields.set_field(CONTENT_TYPE_FIELD)
16301640
assert actual.fields == expected.fields
16311641

16321642
if case.request.body:
@@ -1651,6 +1661,9 @@ async def test_serialize_http_response(case: HTTPMessageTestCase) -> None:
16511661
expected = case.request
16521662

16531663
assert actual is not None
1664+
# Remove content-type from expected, we're re-using the request cases for brevity
1665+
if expected.fields.get(CONTENT_TYPE_FIELD.name) is not None:
1666+
del expected.fields[CONTENT_TYPE_FIELD.name]
16541667
assert actual.fields == expected.fields
16551668
assert actual.status == expected.status
16561669

0 commit comments

Comments
 (0)