Skip to content

Commit 3c86783

Browse files
Add HTTP ShapeDeserializer
1 parent 4247c7f commit 3c86783

File tree

1 file changed

+186
-0
lines changed

1 file changed

+186
-0
lines changed
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
from collections.abc import Callable
2+
from typing import TYPE_CHECKING
3+
from decimal import Decimal
4+
import datetime
5+
6+
from smithy_core.deserializers import ShapeDeserializer, SpecificShapeDeserializer
7+
from smithy_core.codecs import Codec
8+
from smithy_core.schemas import Schema
9+
from smithy_core.traits import (
10+
HTTPTrait,
11+
HTTPHeaderTrait,
12+
HTTPPrefixHeadersTrait,
13+
HTTPPayloadTrait,
14+
HTTPResponseCodeTrait,
15+
TimestampFormatTrait,
16+
)
17+
from smithy_core.utils import strict_parse_bool, strict_parse_float, ensure_utc
18+
from smithy_core.types import TimestampFormat
19+
from smithy_core.shapes import ShapeType
20+
21+
from .aio.interfaces import HTTPResponse
22+
23+
from .interfaces import Field, Fields
24+
25+
if TYPE_CHECKING:
26+
from smithy_core.interfaces import StreamingBlob as _Stream
27+
from smithy_core.aio.interfaces import StreamingBlob as _AsyncStream
28+
from smithy_core.interfaces import BytesReader as _BytesReader
29+
30+
31+
__all__ = ["HTTPResponseDeserializer"]
32+
33+
34+
class HTTPResponseDeserializer(SpecificShapeDeserializer):
35+
# Note: caller will have to read the body if it's async and not streaming
36+
def __init__(
37+
self,
38+
payload_codec: Codec,
39+
http_trait: HTTPTrait,
40+
response: HTTPResponse,
41+
body: _Stream | None = None,
42+
) -> None:
43+
self._payload_codec = payload_codec
44+
self._response = response
45+
self._http_trait = http_trait
46+
if isinstance(body, bytearray):
47+
body = bytes(body)
48+
self._body = body
49+
50+
def read_struct(
51+
self, schema: Schema, consumer: Callable[[Schema, ShapeDeserializer], None]
52+
) -> None:
53+
status = self._response.status
54+
if status != self._http_trait.code and status >= 300:
55+
# TODO: handle exceptions
56+
raise Exception()
57+
58+
## add fields contains, fields get by prefix
59+
60+
has_body = False
61+
payload_member: Schema | None = None
62+
63+
for member in schema.members.values():
64+
if (trait := member.get_trait(HTTPHeaderTrait)) is not None:
65+
if (header := self._response.fields.entries.get(trait.key)) is not None:
66+
if schema.shape_type is ShapeType.LIST:
67+
consumer(member, HTTPHeaderListDeserializer(header))
68+
else:
69+
consumer(member, HTTPHeaderDeserializer(header.as_string()))
70+
elif (trait := member.get_trait(HTTPPrefixHeadersTrait)) is not None:
71+
consumer(member, HTTPHeaderMapDeserializer(self._response.fields))
72+
elif HTTPPayloadTrait in member:
73+
has_body = True
74+
payload_member = member
75+
elif HTTPResponseCodeTrait in member:
76+
consumer(member, HTTPResponseCodeDeserializer(self._response.status))
77+
else:
78+
has_body = True
79+
80+
if has_body and self._body is not None:
81+
if payload_member is not None:
82+
consumer(payload_member, RawPayloadDeserializer(self._body))
83+
else:
84+
payload_deserializer = self._payload_codec.create_deserializer(
85+
self._body
86+
)
87+
payload_deserializer.read_struct(schema, consumer)
88+
89+
90+
class HTTPHeaderDeserializer(SpecificShapeDeserializer):
91+
def __init__(self, value: str) -> None:
92+
self._value = value
93+
94+
def read_boolean(self, schema: Schema) -> bool:
95+
return strict_parse_bool(self._value)
96+
97+
def read_byte(self, schema: Schema) -> int:
98+
return self.read_integer(schema)
99+
100+
def read_short(self, schema: Schema) -> int:
101+
return self.read_integer(schema)
102+
103+
def read_integer(self, schema: Schema) -> int:
104+
return int(self._value)
105+
106+
def read_long(self, schema: Schema) -> int:
107+
return self.read_integer(schema)
108+
109+
def read_big_integer(self, schema: Schema) -> int:
110+
return self.read_integer(schema)
111+
112+
def read_float(self, schema: Schema) -> float:
113+
return strict_parse_float(self._value)
114+
115+
def read_double(self, schema: Schema) -> float:
116+
return self.read_float(schema)
117+
118+
def read_big_decimal(self, schema: Schema) -> Decimal:
119+
return Decimal(self._value).canonical()
120+
121+
def read_string(self, schema: Schema) -> str:
122+
return self._value
123+
124+
def read_timestamp(self, schema: Schema) -> datetime.datetime:
125+
format = TimestampFormat.HTTP_DATE
126+
if (trait := schema.get_trait(TimestampFormatTrait)) is not None:
127+
format = trait.format
128+
return ensure_utc(format.deserialize(self._value))
129+
130+
131+
class HTTPHeaderListDeserializer(SpecificShapeDeserializer):
132+
def __init__(self, field: Field) -> None:
133+
self._field = field
134+
135+
def read_list(
136+
self, schema: Schema, consumer: Callable[["ShapeDeserializer"], None]
137+
) -> None:
138+
for value in self._field.values:
139+
consumer(HTTPHeaderDeserializer(value))
140+
141+
142+
class HTTPHeaderMapDeserializer(SpecificShapeDeserializer):
143+
def __init__(self, fields: Fields) -> None:
144+
self._fields = fields
145+
146+
def read_map(
147+
self,
148+
schema: Schema,
149+
consumer: Callable[[str, "ShapeDeserializer"], None],
150+
) -> None:
151+
prefix = schema.expect_trait(HTTPPrefixHeadersTrait).prefix
152+
for field in self._fields:
153+
if field.name.startswith(prefix):
154+
consumer(field.name, HTTPHeaderDeserializer(field.as_string()))
155+
156+
157+
class HTTPResponseCodeDeserializer(SpecificShapeDeserializer):
158+
def __init__(self, response_code: int) -> None:
159+
self._response_code = response_code
160+
161+
def read_byte(self, schema: Schema) -> int:
162+
return self._response_code
163+
164+
def read_short(self, schema: Schema) -> int:
165+
return self._response_code
166+
167+
def read_integer(self, schema: Schema) -> int:
168+
return self._response_code
169+
170+
171+
class RawPayloadDeserializer(SpecificShapeDeserializer):
172+
def __init__(self, payload: bytes | _BytesReader) -> None:
173+
self._payload = payload
174+
175+
def read_string(self, schema: Schema) -> str:
176+
if not isinstance(self._payload, bytes):
177+
self._payload = self._payload.read()
178+
return self._payload.decode("utf-8")
179+
180+
def read_blob(self, schema: Schema) -> bytes:
181+
if not isinstance(self._payload, bytes):
182+
self._payload = self._payload.read()
183+
return self._payload
184+
185+
def read_data_stream(self, schema: Schema) -> _AsyncStream:
186+
return self._payload

0 commit comments

Comments
 (0)