Skip to content

Commit c5c215e

Browse files
Deserialize string/blob event payloads directly
1 parent 4d51e95 commit c5c215e

File tree

3 files changed

+71
-34
lines changed

3 files changed

+71
-34
lines changed
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3+
from smithy_core.schemas import Schema
4+
5+
from .traits import EVENT_PAYLOAD_TRAIT
36

47
INITIAL_REQUEST_EVENT_TYPE = "initial-request"
58
INITIAL_RESPONSE_EVENT_TYPE = "initial-response"
9+
10+
11+
def get_payload_member(schema: Schema) -> Schema | None:
12+
for member in schema.members.values():
13+
if EVENT_PAYLOAD_TRAIT in member.traits:
14+
return member
15+
return None

python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,18 @@
1212
SpecificShapeDeserializer,
1313
)
1414
from smithy_core.schemas import Schema
15+
from smithy_core.shapes import ShapeType
1516
from smithy_core.utils import expect_type
1617
from smithy_event_stream.aio.interfaces import AsyncEventReceiver
1718

1819
from ..events import HEADERS_DICT, Event
1920
from ..exceptions import EventError, UnmodeledEventError
20-
from . import INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE
21-
from .traits import EVENT_HEADER_TRAIT, EVENT_PAYLOAD_TRAIT
21+
from . import (
22+
INITIAL_REQUEST_EVENT_TYPE,
23+
INITIAL_RESPONSE_EVENT_TYPE,
24+
get_payload_member,
25+
)
26+
from .traits import EVENT_HEADER_TRAIT
2227

2328
INITIAL_MESSAGE_TYPES = (INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE)
2429

@@ -69,25 +74,24 @@ def read_struct(
6974
) -> None:
7075
headers = self._event.message.headers
7176

72-
payload_deserializer = None
73-
if self._event.message.payload:
74-
payload_deserializer = self._payload_codec.create_deserializer(
75-
self._event.message.payload
76-
)
77-
78-
message_deserializer = EventMessageDeserializer(headers, payload_deserializer)
79-
8077
match headers.get(":message-type"):
8178
case "event":
8279
member_name = expect_type(str, headers[":event-type"])
8380
if member_name in INITIAL_MESSAGE_TYPES:
8481
# If it's an initial message, skip straight to deserialization.
82+
message_deserializer = self._create_deserializer(schema, headers)
8583
message_deserializer.read_struct(schema, consumer)
8684
else:
87-
consumer(schema.members[member_name], message_deserializer)
85+
member_schema = schema.members[member_name]
86+
message_deserializer = self._create_deserializer(
87+
member_schema, headers
88+
)
89+
consumer(member_schema, message_deserializer)
8890
case "exception":
8991
member_name = expect_type(str, headers[":exception-type"])
90-
consumer(schema.members[member_name], message_deserializer)
92+
member_schema = schema.members[member_name]
93+
message_deserializer = self._create_deserializer(member_schema, headers)
94+
consumer(member_schema, message_deserializer)
9195
case "error":
9296
# The `application/vnd.amazon.eventstream` format allows for explicitly
9397
# unmodeled exceptions. These exceptions MUST have the `:error-code`
@@ -99,13 +103,49 @@ def read_struct(
99103
case _:
100104
raise EventError(f"Unknown event structure: {self._event}")
101105

106+
def _create_deserializer(
107+
self, schema: Schema, headers: HEADERS_DICT
108+
) -> ShapeDeserializer:
109+
payload_member = get_payload_member(schema)
110+
payload_deserializer = self._create_payload_deserializer(payload_member)
111+
return EventMessageDeserializer(headers, payload_deserializer, payload_member)
112+
113+
def _create_payload_deserializer(
114+
self, payload_member: Schema | None
115+
) -> ShapeDeserializer | None:
116+
if not self._event.message.payload:
117+
return
118+
119+
if payload_member is not None and payload_member.shape_type in (
120+
ShapeType.BLOB,
121+
ShapeType.STRING,
122+
):
123+
return RawPayloadDeserializer(self._event.message.payload)
124+
125+
return self._payload_codec.create_deserializer(self._event.message.payload)
126+
127+
128+
class RawPayloadDeserializer(SpecificShapeDeserializer):
129+
def __init__(self, payload: bytes) -> None:
130+
self._payload = payload
131+
132+
def read_string(self, schema: Schema) -> str:
133+
return self._payload.decode("utf-8")
134+
135+
def read_blob(self, schema: Schema) -> bytes:
136+
return self._payload
137+
102138

103139
class EventMessageDeserializer(SpecificShapeDeserializer):
104140
def __init__(
105-
self, headers: HEADERS_DICT, payload_deserializer: ShapeDeserializer | None
141+
self,
142+
headers: HEADERS_DICT,
143+
payload_deserializer: ShapeDeserializer | None,
144+
payload_member: Schema | None,
106145
) -> None:
107146
self._headers = headers
108147
self._payload_deserializer = payload_deserializer
148+
self._payload_member = payload_member
109149

110150
def read_struct(
111151
self,
@@ -119,17 +159,11 @@ def read_struct(
119159
consumer(member_schema, headers_deserializer)
120160

121161
if self._payload_deserializer:
122-
if (payload_member := self._get_payload_member(schema)) is not None:
123-
consumer(payload_member, self._payload_deserializer)
162+
if self._payload_member is not None:
163+
consumer(self._payload_member, self._payload_deserializer)
124164
else:
125165
self._payload_deserializer.read_struct(schema, consumer)
126166

127-
def _get_payload_member(self, schema: "Schema") -> "Schema | None":
128-
for member in schema.members.values():
129-
if EVENT_PAYLOAD_TRAIT in member.traits:
130-
return member
131-
return None
132-
133167

134168
class EventHeaderDeserializer(SpecificShapeDeserializer):
135169
def __init__(self, headers: HEADERS_DICT) -> None:

python-packages/aws-event-stream/aws_event_stream/_private/serializers.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,12 @@
2323

2424
from ..events import EventHeaderEncoder, EventMessage
2525
from ..exceptions import InvalidHeaderValue
26-
from . import INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE
27-
from .traits import (
28-
ERROR_TRAIT,
29-
EVENT_HEADER_TRAIT,
30-
EVENT_PAYLOAD_TRAIT,
31-
MEDIA_TYPE_TRAIT,
26+
from . import (
27+
INITIAL_REQUEST_EVENT_TYPE,
28+
INITIAL_RESPONSE_EVENT_TYPE,
29+
get_payload_member,
3230
)
31+
from .traits import ERROR_TRAIT, EVENT_HEADER_TRAIT, MEDIA_TYPE_TRAIT
3332

3433
_DEFAULT_STRING_CONTENT_TYPE = "text/plain"
3534
_DEFAULT_BLOB_CONTENT_TYPE = "application/octet-stream"
@@ -129,7 +128,7 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]:
129128

130129
media_type = self._payload_codec.media_type
131130

132-
if (payload_member := self._get_payload_member(schema)) is not None:
131+
if (payload_member := get_payload_member(schema)) is not None:
133132
media_type = self._get_payload_media_type(payload_member, media_type)
134133
if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING):
135134
payload_serializer = RawPayloadSerializer(payload)
@@ -146,12 +145,6 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]:
146145
headers_bytes=headers_encoder.get_result(), payload=payload_bytes
147146
)
148147

149-
def _get_payload_member(self, schema: Schema) -> Schema | None:
150-
for member in schema.members.values():
151-
if EVENT_PAYLOAD_TRAIT in member.traits:
152-
return member
153-
return None
154-
155148
def _get_payload_media_type(self, schema: Schema, default: str) -> str:
156149
if (media_type := schema.traits.get(MEDIA_TYPE_TRAIT)) is not None:
157150
return expect_type(str, media_type.value)

0 commit comments

Comments
 (0)