Skip to content

Commit 4f3304e

Browse files
Add HTTP ShapeDeserializer
1 parent eedcbf3 commit 4f3304e

File tree

1 file changed

+215
-0
lines changed

1 file changed

+215
-0
lines changed
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
from collections.abc import Callable
2+
from typing import TYPE_CHECKING, TypeGuard
3+
from decimal import Decimal
4+
import datetime
5+
from asyncio import iscoroutinefunction
6+
7+
from smithy_core.deserializers import ShapeDeserializer, SpecificShapeDeserializer
8+
from smithy_core.codecs import Codec
9+
from smithy_core.schemas import Schema
10+
from smithy_core.traits import (
11+
HTTPTrait,
12+
HTTPHeaderTrait,
13+
HTTPPrefixHeadersTrait,
14+
HTTPPayloadTrait,
15+
HTTPResponseCodeTrait,
16+
TimestampFormatTrait,
17+
)
18+
from smithy_core.utils import strict_parse_bool, strict_parse_float, ensure_utc
19+
from smithy_core.types import TimestampFormat
20+
from smithy_core.shapes import ShapeType
21+
from smithy_core.exceptions import UnsupportedStreamException
22+
23+
from .aio.interfaces import HTTPResponse
24+
25+
from .interfaces import Field, Fields
26+
27+
if TYPE_CHECKING:
28+
from smithy_core.interfaces import StreamingBlob as _SyncStream
29+
from smithy_core.aio.interfaces import StreamingBlob as _AsyncStream
30+
from smithy_core.interfaces import BytesReader as _BytesReader
31+
32+
33+
__all__ = ["HTTPResponseDeserializer"]
34+
35+
36+
def _is_sync_reader(stream: "_AsyncStream") -> TypeGuard["_BytesReader"]:
37+
return (
38+
read := getattr(stream, "read", None)
39+
) is not None and not iscoroutinefunction(read)
40+
41+
42+
def _is_sync_stream(stream: "_AsyncStream") -> TypeGuard["_SyncStream"]:
43+
return isinstance(stream, bytes | bytearray) or _is_sync_reader(stream)
44+
45+
46+
class HTTPResponseDeserializer(SpecificShapeDeserializer):
47+
# Note: caller will have to read the body if it's async and not streaming
48+
def __init__(
49+
self,
50+
payload_codec: Codec,
51+
http_trait: HTTPTrait,
52+
response: HTTPResponse,
53+
body: "_SyncStream | None" = None,
54+
) -> None:
55+
self._payload_codec = payload_codec
56+
self._response = response
57+
self._http_trait = http_trait
58+
self._body = body
59+
60+
def read_struct(
61+
self, schema: Schema, consumer: Callable[[Schema, ShapeDeserializer], None]
62+
) -> None:
63+
has_body = False
64+
payload_member: Schema | None = None
65+
66+
for member in schema.members.values():
67+
if (trait := member.get_trait(HTTPHeaderTrait)) is not None:
68+
header = self._response.fields.entries.get(trait.key.lower())
69+
if header is not None:
70+
if member.shape_type is ShapeType.LIST:
71+
consumer(member, HTTPHeaderListDeserializer(header))
72+
else:
73+
consumer(member, HTTPHeaderDeserializer(header.as_string()))
74+
elif (trait := member.get_trait(HTTPPrefixHeadersTrait)) is not None:
75+
consumer(
76+
member,
77+
HTTPHeaderMapDeserializer(trait.prefix, self._response.fields),
78+
)
79+
elif HTTPPayloadTrait in member:
80+
has_body = True
81+
payload_member = member
82+
elif HTTPResponseCodeTrait in member:
83+
consumer(member, HTTPResponseCodeDeserializer(self._response.status))
84+
else:
85+
has_body = True
86+
87+
if has_body:
88+
deserializer = self._create_payload_deserializer(payload_member)
89+
if payload_member is not None:
90+
consumer(payload_member, deserializer)
91+
else:
92+
deserializer.read_struct(schema, consumer)
93+
94+
def _create_payload_deserializer(
95+
self, payload_member: Schema | None
96+
) -> ShapeDeserializer:
97+
body = self._body if self._body is not None else self._response.body
98+
if payload_member is not None and payload_member.shape_type in (
99+
ShapeType.BLOB,
100+
ShapeType.STRING,
101+
):
102+
return RawPayloadDeserializer(body)
103+
104+
if not _is_sync_stream(body):
105+
raise UnsupportedStreamException()
106+
107+
if isinstance(body, bytearray):
108+
body = bytes(body)
109+
110+
return self._payload_codec.create_deserializer(body)
111+
112+
113+
class HTTPHeaderDeserializer(SpecificShapeDeserializer):
114+
def __init__(self, value: str) -> None:
115+
self._value = value
116+
117+
def read_boolean(self, schema: Schema) -> bool:
118+
return strict_parse_bool(self._value)
119+
120+
def read_byte(self, schema: Schema) -> int:
121+
return self.read_integer(schema)
122+
123+
def read_short(self, schema: Schema) -> int:
124+
return self.read_integer(schema)
125+
126+
def read_integer(self, schema: Schema) -> int:
127+
return int(self._value)
128+
129+
def read_long(self, schema: Schema) -> int:
130+
return self.read_integer(schema)
131+
132+
def read_big_integer(self, schema: Schema) -> int:
133+
return self.read_integer(schema)
134+
135+
def read_float(self, schema: Schema) -> float:
136+
return strict_parse_float(self._value)
137+
138+
def read_double(self, schema: Schema) -> float:
139+
return self.read_float(schema)
140+
141+
def read_big_decimal(self, schema: Schema) -> Decimal:
142+
return Decimal(self._value).canonical()
143+
144+
def read_string(self, schema: Schema) -> str:
145+
return self._value
146+
147+
def read_timestamp(self, schema: Schema) -> datetime.datetime:
148+
format = TimestampFormat.HTTP_DATE
149+
if (trait := schema.get_trait(TimestampFormatTrait)) is not None:
150+
format = trait.format
151+
return ensure_utc(format.deserialize(self._value))
152+
153+
154+
class HTTPHeaderListDeserializer(SpecificShapeDeserializer):
155+
def __init__(self, field: Field) -> None:
156+
self._field = field
157+
158+
def read_list(
159+
self, schema: Schema, consumer: Callable[["ShapeDeserializer"], None]
160+
) -> None:
161+
for value in self._field.values:
162+
consumer(HTTPHeaderDeserializer(value))
163+
164+
165+
class HTTPHeaderMapDeserializer(SpecificShapeDeserializer):
166+
def __init__(self, prefix: str, fields: Fields) -> None:
167+
self._prefix = prefix.lower()
168+
self._fields = fields
169+
170+
def read_map(
171+
self,
172+
schema: Schema,
173+
consumer: Callable[[str, "ShapeDeserializer"], None],
174+
) -> None:
175+
trim = len(self._prefix)
176+
for field in self._fields:
177+
if field.name.lower().startswith(self._prefix):
178+
consumer(field.name[trim:], HTTPHeaderDeserializer(field.as_string()))
179+
180+
181+
class HTTPResponseCodeDeserializer(SpecificShapeDeserializer):
182+
def __init__(self, response_code: int) -> None:
183+
self._response_code = response_code
184+
185+
def read_byte(self, schema: Schema) -> int:
186+
return self._response_code
187+
188+
def read_short(self, schema: Schema) -> int:
189+
return self._response_code
190+
191+
def read_integer(self, schema: Schema) -> int:
192+
return self._response_code
193+
194+
195+
class RawPayloadDeserializer(SpecificShapeDeserializer):
196+
def __init__(self, payload: "_AsyncStream") -> None:
197+
self._payload = payload
198+
199+
def read_string(self, schema: Schema) -> str:
200+
return self._consume_payload().decode("utf-8")
201+
202+
def read_blob(self, schema: Schema) -> bytes:
203+
return self._consume_payload()
204+
205+
def read_data_stream(self, schema: Schema) -> "_AsyncStream":
206+
return self._payload
207+
208+
def _consume_payload(self) -> bytes:
209+
if isinstance(self._payload, bytes):
210+
return self._payload
211+
if isinstance(self._payload, bytearray):
212+
return bytes(self._payload)
213+
if _is_sync_reader(self._payload):
214+
return self._payload.read()
215+
raise UnsupportedStreamException

0 commit comments

Comments
 (0)