Skip to content

Commit 5520941

Browse files
Centralize http binding matching and omit empty payloads
This changes the way members are classified to use a separate binding matcher. This lets us use the match statement rather than if-else chains and ensures we're minimizing the number of times we have to iterate over the member list. This also adds the ability to select whether a payload should be omitted if nothing is bound to it.
1 parent 81baaa8 commit 5520941

File tree

6 files changed

+488
-102
lines changed

6 files changed

+488
-102
lines changed

packages/smithy-core/src/smithy_core/retries.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,11 +234,11 @@ def refresh_retry_token_for_retry(
234234
if retry_count >= self.max_attempts:
235235
raise RetryError(
236236
f"Reached maximum number of allowed attempts: {self.max_attempts}"
237-
)
237+
) from error
238238
retry_delay = self.backoff_strategy.compute_next_backoff_delay(retry_count)
239239
return SimpleRetryToken(retry_count=retry_count, retry_delay=retry_delay)
240240
else:
241-
raise RetryError(f"Error is not retryable: {error}")
241+
raise RetryError(f"Error is not retryable: {error}") from error
242242

243243
def record_success(self, *, token: retries_interface.RetryToken) -> None:
244244
"""Not used by this retry strategy."""
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
from dataclasses import dataclass
4+
from enum import Enum
5+
6+
from smithy_core.schemas import Schema
7+
from smithy_core.shapes import ShapeType
8+
from smithy_core.traits import (
9+
ErrorFault,
10+
ErrorTrait,
11+
HostLabelTrait,
12+
HTTPErrorTrait,
13+
HTTPHeaderTrait,
14+
HTTPLabelTrait,
15+
HTTPPayloadTrait,
16+
HTTPPrefixHeadersTrait,
17+
HTTPQueryParamsTrait,
18+
HTTPQueryTrait,
19+
HTTPResponseCodeTrait,
20+
StreamingTrait,
21+
)
22+
23+
24+
class Binding(Enum):
25+
"""HTTP binding locations."""
26+
27+
HEADER = 0
28+
"""Indicates the member is bound to a header."""
29+
30+
QUERY = 1
31+
"""Indicates the member is bound to a query parameter."""
32+
33+
PAYLOAD = 2
34+
"""Indicates the member is bound to the entire HTTP payload."""
35+
36+
BODY = 3
37+
"""Indicates the member is a property in the HTTP payload structure."""
38+
39+
LABEL = 4
40+
"""Indicates the member is bound to a path segment in the URI."""
41+
42+
STATUS = 5
43+
"""Indicates the member is bound to the response status code."""
44+
45+
PREFIX_HEADERS = 6
46+
"""Indicates the member is bound to multiple headers with a shared prefix."""
47+
48+
QUERY_PARAMS = 7
49+
"""Indicates the member is bound to the query string as multiple key-value pairs."""
50+
51+
HOST = 8
52+
"""Indicates the member is bound to a prefix to the host AND to the body."""
53+
54+
55+
@dataclass(init=False)
56+
class _BindingMatcher:
57+
bindings: list[Binding]
58+
"""A list of bindings where the index matches the index of the member schema."""
59+
60+
response_status: int
61+
"""The default response status code."""
62+
63+
has_body: bool
64+
"""Whether the HTTP message has members bound to the body."""
65+
66+
has_payload: bool
67+
"""Whether the HTTP message has a member bound to the entire payload."""
68+
69+
payload_member: Schema | None
70+
"""The member bound to the payload, if one exists."""
71+
72+
event_stream_member: Schema | None
73+
"""The member bound to the event stream, if one exists."""
74+
75+
def __init__(self, struct: Schema, response_status: int) -> None:
76+
self.response_status = response_status
77+
found_body = False
78+
found_payload = False
79+
self.bindings = [Binding.BODY] * len(struct.members)
80+
self.payload_member = None
81+
self.event_stream_member = None
82+
83+
for member in struct.members.values():
84+
binding = self._do_match(member)
85+
self.bindings[member.expect_member_index()] = binding
86+
found_body = (
87+
found_body or binding is Binding.BODY or binding is Binding.HOST
88+
)
89+
if binding is Binding.PAYLOAD:
90+
found_payload = True
91+
self.payload_member = member
92+
if (
93+
StreamingTrait.id in member.traits
94+
and member.shape_type is ShapeType.UNION
95+
):
96+
self.event_stream_member = member
97+
98+
self.has_body = found_body
99+
self.has_payload = found_payload
100+
101+
def should_write_body(self, omit_empty_payload: bool) -> bool:
102+
"""Determines whether a body should be written.
103+
104+
:param omit_empty_payload: Whether a body should be skipped in the case of an
105+
empty payload.
106+
"""
107+
return self.has_body or (not omit_empty_payload and not self.has_payload)
108+
109+
def match(self, member: Schema) -> Binding:
110+
"""Determines which part of the HTTP message the given member is bound to."""
111+
return self.bindings[member.expect_member_index()]
112+
113+
def _do_match(self, member: Schema) -> Binding: ...
114+
115+
116+
@dataclass(init=False)
117+
class RequestBindingMatcher(_BindingMatcher):
118+
"""Matches structure members to HTTP request binding locations."""
119+
120+
def __init__(self, struct: Schema) -> None:
121+
"""Initialize a RequestBindingMatcher.
122+
123+
:param struct: The structure to examine for HTTP bindings.
124+
"""
125+
super().__init__(struct, -1)
126+
127+
def _do_match(self, member: Schema) -> Binding:
128+
if HTTPLabelTrait.id in member.traits:
129+
return Binding.LABEL
130+
if HTTPQueryTrait.id in member.traits:
131+
return Binding.QUERY
132+
if HTTPQueryParamsTrait.id in member.traits:
133+
return Binding.QUERY_PARAMS
134+
if HTTPHeaderTrait.id in member.traits:
135+
return Binding.HEADER
136+
if HTTPPrefixHeadersTrait.id in member.traits:
137+
return Binding.PREFIX_HEADERS
138+
if HTTPPayloadTrait.id in member.traits:
139+
return Binding.PAYLOAD
140+
if HostLabelTrait.id in member.traits:
141+
return Binding.HOST
142+
return Binding.BODY
143+
144+
145+
@dataclass(init=False)
146+
class ResponseBindingMatcher(_BindingMatcher):
147+
"""Matches structure members to HTTP response binding locations."""
148+
149+
def __init__(self, struct: Schema) -> None:
150+
"""Initialize a ResponseBindingMatcher.
151+
152+
:param struct: The structure to examine for HTTP bindings.
153+
"""
154+
super().__init__(struct, self._compute_response(struct))
155+
156+
def _compute_response(self, struct: Schema) -> int:
157+
if (http_error := struct.get_trait(HTTPErrorTrait)) is not None:
158+
return http_error.code
159+
if (error := struct.get_trait(ErrorTrait)) is not None:
160+
return 400 if error.fault is ErrorFault.CLIENT else 500
161+
return -1
162+
163+
def _do_match(self, member: Schema) -> Binding:
164+
if HTTPResponseCodeTrait.id in member.traits:
165+
return Binding.STATUS
166+
if HTTPHeaderTrait.id in member.traits:
167+
return Binding.HEADER
168+
if HTTPPrefixHeadersTrait.id in member.traits:
169+
return Binding.PREFIX_HEADERS
170+
if HTTPPayloadTrait.id in member.traits:
171+
return Binding.PAYLOAD
172+
return Binding.BODY

packages/smithy-http/src/smithy_http/deserializers.py

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,15 @@
1313
from smithy_core.shapes import ShapeType
1414
from smithy_core.traits import (
1515
HTTPHeaderTrait,
16-
HTTPPayloadTrait,
1716
HTTPPrefixHeadersTrait,
18-
HTTPResponseCodeTrait,
1917
HTTPTrait,
2018
TimestampFormatTrait,
2119
)
2220
from smithy_core.types import TimestampFormat
2321
from smithy_core.utils import ensure_utc, strict_parse_bool, strict_parse_float
2422

2523
from .aio.interfaces import HTTPResponse
24+
from .bindings import Binding, ResponseBindingMatcher
2625
from .interfaces import Field, Fields
2726

2827
if TYPE_CHECKING:
@@ -61,47 +60,49 @@ def __init__(
6160
def read_struct(
6261
self, schema: Schema, consumer: Callable[[Schema, ShapeDeserializer], None]
6362
) -> None:
64-
has_body = False
65-
payload_member: Schema | None = None
63+
binding_matcher = ResponseBindingMatcher(schema)
6664

6765
for member in schema.members.values():
68-
if (trait := member.get_trait(HTTPHeaderTrait)) is not None:
69-
header = self._response.fields.entries.get(trait.key.lower())
70-
if header is not None:
71-
if member.shape_type is ShapeType.LIST:
72-
consumer(member, HTTPHeaderListDeserializer(header))
73-
else:
74-
consumer(member, HTTPHeaderDeserializer(header.as_string()))
75-
elif (trait := member.get_trait(HTTPPrefixHeadersTrait)) is not None:
76-
consumer(
77-
member,
78-
HTTPHeaderMapDeserializer(self._response.fields, trait.prefix),
79-
)
80-
elif HTTPPayloadTrait in member:
81-
has_body = True
82-
payload_member = member
83-
elif HTTPResponseCodeTrait in member:
84-
consumer(member, HTTPResponseCodeDeserializer(self._response.status))
85-
else:
86-
has_body = True
87-
88-
if has_body:
89-
deserializer = self._create_payload_deserializer(payload_member)
90-
if payload_member is not None:
91-
consumer(payload_member, deserializer)
92-
else:
93-
deserializer.read_struct(schema, consumer)
94-
95-
def _create_payload_deserializer(
96-
self, payload_member: Schema | None
97-
) -> ShapeDeserializer:
98-
body = self._body if self._body is not None else self._response.body
99-
if payload_member is not None and payload_member.shape_type in (
100-
ShapeType.BLOB,
101-
ShapeType.STRING,
102-
):
66+
match binding_matcher.match(member):
67+
case Binding.HEADER:
68+
trait = member.expect_trait(HTTPHeaderTrait)
69+
header = self._response.fields.entries.get(trait.key.lower())
70+
if header is not None:
71+
if member.shape_type is ShapeType.LIST:
72+
consumer(member, HTTPHeaderListDeserializer(header))
73+
else:
74+
consumer(member, HTTPHeaderDeserializer(header.as_string()))
75+
case Binding.PREFIX_HEADERS:
76+
trait = member.expect_trait(HTTPPrefixHeadersTrait)
77+
consumer(
78+
member,
79+
HTTPHeaderMapDeserializer(self._response.fields, trait.prefix),
80+
)
81+
case Binding.STATUS:
82+
consumer(
83+
member, HTTPResponseCodeDeserializer(self._response.status)
84+
)
85+
case Binding.PAYLOAD:
86+
assert binding_matcher.payload_member is not None # noqa: S101
87+
deserializer = self._create_payload_deserializer(
88+
binding_matcher.payload_member
89+
)
90+
consumer(binding_matcher.payload_member, deserializer)
91+
case _:
92+
pass
93+
94+
if binding_matcher.has_body:
95+
deserializer = self._create_body_deserializer()
96+
deserializer.read_struct(schema, consumer)
97+
98+
def _create_payload_deserializer(self, payload_member: Schema) -> ShapeDeserializer:
99+
if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING):
100+
body = self._body if self._body is not None else self._response.body
103101
return RawPayloadDeserializer(body)
102+
return self._create_body_deserializer()
104103

104+
def _create_body_deserializer(self):
105+
body = self._body if self._body is not None else self._response.body
105106
if not is_streaming_blob(body):
106107
raise UnsupportedStreamError(
107108
"Unable to read async stream. This stream must be buffered prior "

0 commit comments

Comments
 (0)