Skip to content

Commit 7755ea3

Browse files
Add HTTP ShapeDeserializer
1 parent eedcbf3 commit 7755ea3

File tree

1 file changed

+198
-0
lines changed

1 file changed

+198
-0
lines changed
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
from collections.abc import Callable
2+
from typing import TYPE_CHECKING
3+
from decimal import Decimal
4+
import datetime
5+
6+
from smithy_core.deserializers import ShapeDeserializer, SpecificShapeDeserializer
7+
from smithy_core.codecs import Codec
8+
from smithy_core.schemas import Schema
9+
from smithy_core.traits import (
10+
HTTPTrait,
11+
HTTPHeaderTrait,
12+
HTTPPrefixHeadersTrait,
13+
HTTPPayloadTrait,
14+
HTTPResponseCodeTrait,
15+
TimestampFormatTrait,
16+
)
17+
from smithy_core.utils import strict_parse_bool, strict_parse_float, ensure_utc
18+
from smithy_core.types import TimestampFormat
19+
from smithy_core.shapes import ShapeType
20+
21+
from .aio.interfaces import HTTPResponse
22+
23+
from .interfaces import Field, Fields
24+
25+
if TYPE_CHECKING:
26+
from smithy_core.interfaces import StreamingBlob as _Stream
27+
from smithy_core.aio.interfaces import StreamingBlob as _AsyncStream
28+
from smithy_core.interfaces import BytesReader as _BytesReader
29+
30+
31+
__all__ = ["HTTPResponseDeserializer"]
32+
33+
34+
class HTTPResponseDeserializer(SpecificShapeDeserializer):
35+
# Note: caller will have to read the body if it's async and not streaming
36+
def __init__(
37+
self,
38+
payload_codec: Codec,
39+
http_trait: HTTPTrait,
40+
response: HTTPResponse,
41+
body: "_Stream | None" = None,
42+
) -> None:
43+
self._payload_codec = payload_codec
44+
self._response = response
45+
self._http_trait = http_trait
46+
if isinstance(body, bytearray):
47+
body = bytes(body)
48+
self._body = body
49+
50+
def read_struct(
51+
self, schema: Schema, consumer: Callable[[Schema, ShapeDeserializer], None]
52+
) -> None:
53+
status = self._response.status
54+
if status != self._http_trait.code and status >= 300:
55+
# TODO: handle exceptions
56+
raise Exception()
57+
58+
has_body = False
59+
payload_member: Schema | None = None
60+
61+
for member in schema.members.values():
62+
if (trait := member.get_trait(HTTPHeaderTrait)) is not None:
63+
if (
64+
header := self._response.fields.entries.get(trait.key.lower())
65+
) is not None:
66+
if member.shape_type is ShapeType.LIST:
67+
consumer(member, HTTPHeaderListDeserializer(header))
68+
else:
69+
consumer(member, HTTPHeaderDeserializer(header.as_string()))
70+
elif (trait := member.get_trait(HTTPPrefixHeadersTrait)) is not None:
71+
consumer(member, HTTPHeaderMapDeserializer(self._response.fields))
72+
elif HTTPPayloadTrait in member:
73+
has_body = True
74+
payload_member = member
75+
elif HTTPResponseCodeTrait in member:
76+
consumer(member, HTTPResponseCodeDeserializer(self._response.status))
77+
else:
78+
has_body = True
79+
80+
if has_body:
81+
deserializer = self._create_payload_deserializer(payload_member)
82+
if payload_member is not None:
83+
consumer(payload_member, deserializer)
84+
else:
85+
deserializer.read_struct(schema, consumer)
86+
87+
def _create_payload_deserializer(
88+
self, payload_member: Schema | None
89+
) -> ShapeDeserializer:
90+
body = self._body if self._body is not None else self._response.body
91+
if payload_member is not None and payload_member.shape_type in (
92+
ShapeType.BLOB,
93+
ShapeType.STRING,
94+
):
95+
return RawPayloadDeserializer(body)
96+
97+
return self._payload_codec.create_deserializer(body)
98+
99+
100+
class HTTPHeaderDeserializer(SpecificShapeDeserializer):
101+
def __init__(self, value: str) -> None:
102+
self._value = value
103+
104+
def read_boolean(self, schema: Schema) -> bool:
105+
return strict_parse_bool(self._value)
106+
107+
def read_byte(self, schema: Schema) -> int:
108+
return self.read_integer(schema)
109+
110+
def read_short(self, schema: Schema) -> int:
111+
return self.read_integer(schema)
112+
113+
def read_integer(self, schema: Schema) -> int:
114+
return int(self._value)
115+
116+
def read_long(self, schema: Schema) -> int:
117+
return self.read_integer(schema)
118+
119+
def read_big_integer(self, schema: Schema) -> int:
120+
return self.read_integer(schema)
121+
122+
def read_float(self, schema: Schema) -> float:
123+
return strict_parse_float(self._value)
124+
125+
def read_double(self, schema: Schema) -> float:
126+
return self.read_float(schema)
127+
128+
def read_big_decimal(self, schema: Schema) -> Decimal:
129+
return Decimal(self._value).canonical()
130+
131+
def read_string(self, schema: Schema) -> str:
132+
return self._value
133+
134+
def read_timestamp(self, schema: Schema) -> datetime.datetime:
135+
format = TimestampFormat.HTTP_DATE
136+
if (trait := schema.get_trait(TimestampFormatTrait)) is not None:
137+
format = trait.format
138+
return ensure_utc(format.deserialize(self._value))
139+
140+
141+
class HTTPHeaderListDeserializer(SpecificShapeDeserializer):
142+
def __init__(self, field: Field) -> None:
143+
self._field = field
144+
145+
def read_list(
146+
self, schema: Schema, consumer: Callable[["ShapeDeserializer"], None]
147+
) -> None:
148+
for value in self._field.values:
149+
consumer(HTTPHeaderDeserializer(value))
150+
151+
152+
class HTTPHeaderMapDeserializer(SpecificShapeDeserializer):
153+
def __init__(self, fields: Fields) -> None:
154+
self._fields = fields
155+
156+
def read_map(
157+
self,
158+
schema: Schema,
159+
consumer: Callable[[str, "ShapeDeserializer"], None],
160+
) -> None:
161+
prefix = schema.expect_trait(HTTPPrefixHeadersTrait).prefix.lower()
162+
for field in self._fields:
163+
if field.name.lower().startswith(prefix):
164+
consumer(
165+
field.name[len(prefix) :], HTTPHeaderDeserializer(field.as_string())
166+
)
167+
168+
169+
class HTTPResponseCodeDeserializer(SpecificShapeDeserializer):
170+
def __init__(self, response_code: int) -> None:
171+
self._response_code = response_code
172+
173+
def read_byte(self, schema: Schema) -> int:
174+
return self._response_code
175+
176+
def read_short(self, schema: Schema) -> int:
177+
return self._response_code
178+
179+
def read_integer(self, schema: Schema) -> int:
180+
return self._response_code
181+
182+
183+
class RawPayloadDeserializer(SpecificShapeDeserializer):
184+
def __init__(self, payload: "bytes | _BytesReader") -> None:
185+
self._payload = payload
186+
187+
def read_string(self, schema: Schema) -> str:
188+
if not isinstance(self._payload, bytes):
189+
self._payload = self._payload.read()
190+
return self._payload.decode("utf-8")
191+
192+
def read_blob(self, schema: Schema) -> bytes:
193+
if not isinstance(self._payload, bytes):
194+
self._payload = self._payload.read()
195+
return self._payload
196+
197+
def read_data_stream(self, schema: Schema) -> "_AsyncStream":
198+
return self._payload

0 commit comments

Comments
 (0)