Skip to content

Commit 4d51e95

Browse files
Serialize string/blob event payloads directly
1 parent cca5234 commit 4d51e95

File tree

2 files changed

+115
-16
lines changed

2 files changed

+115
-16
lines changed

python-packages/aws-event-stream/aws_event_stream/_private/serializers.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]:
131131

132132
if (payload_member := self._get_payload_member(schema)) is not None:
133133
media_type = self._get_payload_media_type(payload_member, media_type)
134+
if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING):
135+
payload_serializer = RawPayloadSerializer(payload)
134136
yield EventStreamBindingSerializer(header_serializer, payload_serializer)
135137
else:
136138
with payload_serializer.begin_struct(schema) as body_serializer:
@@ -200,19 +202,30 @@ def write_timestamp(self, schema: "Schema", value: datetime.datetime) -> None:
200202
self._encoder.encode_timestamp(schema.expect_member_name(), value)
201203

202204

205+
class RawPayloadSerializer(SpecificShapeSerializer):
206+
def __init__(self, payload: BytesIO) -> None:
207+
self._payload = payload
208+
209+
def write_string(self, schema: "Schema", value: str) -> None:
210+
self._payload.write(value.encode("utf-8"))
211+
212+
def write_blob(self, schema: "Schema", value: bytes) -> None:
213+
self._payload.write(value)
214+
215+
203216
class EventStreamBindingSerializer(InterceptingSerializer):
204217
def __init__(
205218
self,
206219
header_serializer: EventHeaderSerializer,
207-
payload_serializer: ShapeSerializer,
220+
payload_struct_serializer: ShapeSerializer,
208221
) -> None:
209222
self._header_serializer = header_serializer
210-
self._payload_serializer = payload_serializer
223+
self._payload_struct_serializer = payload_struct_serializer
211224

212225
def before(self, schema: "Schema") -> ShapeSerializer:
213226
if EVENT_HEADER_TRAIT in schema.traits:
214227
return self._header_serializer
215-
return self._payload_serializer
228+
return self._payload_struct_serializer
216229

217230
def after(self, schema: "Schema") -> None:
218231
pass

python-packages/aws-event-stream/tests/unit/_private/__init__.py

Lines changed: 99 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,27 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
3-
from dataclasses import dataclass
43
import datetime
5-
from typing import Any, Self, ClassVar, Literal
6-
7-
from aws_event_stream.events import EventMessage, Short, Byte, Long
4+
from dataclasses import dataclass
5+
from typing import Any, ClassVar, Literal, Self
86

9-
from smithy_core.serializers import ShapeSerializer
107
from smithy_core.deserializers import ShapeDeserializer
11-
from smithy_core.schemas import Schema
128
from smithy_core.exceptions import SmithyException
13-
from smithy_core.shapes import ShapeID, ShapeType
14-
from smithy_core.traits import Trait
159
from smithy_core.prelude import (
10+
BLOB,
1611
BOOLEAN,
1712
BYTE,
18-
SHORT,
1913
INTEGER,
2014
LONG,
21-
BLOB,
15+
SHORT,
2216
STRING,
2317
TIMESTAMP,
2418
)
19+
from smithy_core.schemas import Schema
20+
from smithy_core.serializers import ShapeSerializer
21+
from smithy_core.shapes import ShapeID, ShapeType
22+
from smithy_core.traits import Trait
2523

24+
from aws_event_stream.events import Byte, EventMessage, Long, Short
2625

2726
EVENT_HEADER_TRAIT = Trait(id=ShapeID("smithy.api#eventHeader"))
2827
EVENT_PAYLOAD_TRAIT = Trait(id=ShapeID("smithy.api#eventPayload"))
@@ -66,6 +65,22 @@
6665
},
6766
)
6867

68+
SCHEMA_BLOB_PAYLOAD_EVENT = Schema.collection(
69+
id=ShapeID("smithy.example#BlobPayloadEvent"),
70+
members={
71+
"header": {
72+
"index": 0,
73+
"target": STRING,
74+
"traits": [EVENT_HEADER_TRAIT, REQUIRED_TRAIT],
75+
},
76+
"payload": {
77+
"index": 1,
78+
"target": BLOB,
79+
"traits": [EVENT_PAYLOAD_TRAIT, REQUIRED_TRAIT],
80+
},
81+
},
82+
)
83+
6984
SCHEMA_ERROR_EVENT = Schema.collection(
7085
id=ShapeID("smithy.example#ErrorEvent"),
7186
members={"message": {"index": 0, "target": STRING, "traits": [REQUIRED_TRAIT]}},
@@ -79,7 +94,8 @@
7994
members={
8095
"message": {"index": 0, "target": SCHEMA_MESSAGE_EVENT},
8196
"payload": {"index": 1, "target": SCHEMA_PAYLOAD_EVENT},
82-
"error": {"index": 2, "target": SCHEMA_ERROR_EVENT},
97+
"blobPayload": {"index": 2, "target": SCHEMA_BLOB_PAYLOAD_EVENT},
98+
"error": {"index": 3, "target": SCHEMA_ERROR_EVENT},
8399
},
84100
)
85101

@@ -273,6 +289,57 @@ def serialize_members(self, serializer: ShapeSerializer):
273289
serializer.write_struct(SCHEMA_EVENT_STREAM.members["payload"], self.value)
274290

275291

292+
@dataclass
293+
class BlobPayloadEvent:
294+
header: str
295+
payload: bytes
296+
297+
def serialize(self, serializer: ShapeSerializer):
298+
with serializer.begin_struct(SCHEMA_BLOB_PAYLOAD_EVENT) as s:
299+
self.serialize_members(s)
300+
301+
def serialize_members(self, serializer: ShapeSerializer) -> None:
302+
serializer.write_string(
303+
SCHEMA_BLOB_PAYLOAD_EVENT.members["header"], self.header
304+
)
305+
serializer.write_blob(
306+
SCHEMA_BLOB_PAYLOAD_EVENT.members["payload"], self.payload
307+
)
308+
309+
@classmethod
310+
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
311+
kwargs: dict[str, Any] = {}
312+
313+
def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
314+
match schema.expect_member_index():
315+
case 0:
316+
kwargs["header"] = de.read_string(
317+
SCHEMA_BLOB_PAYLOAD_EVENT.members["header"]
318+
)
319+
case 1:
320+
kwargs["payload"] = de.read_blob(
321+
SCHEMA_BLOB_PAYLOAD_EVENT.members["payload"]
322+
)
323+
case _:
324+
raise SmithyException(f"Unexpected member schema: {schema}")
325+
326+
deserializer.read_struct(schema=SCHEMA_BLOB_PAYLOAD_EVENT, consumer=_consumer)
327+
return cls(**kwargs)
328+
329+
330+
@dataclass
331+
class EventStreamBlobPayloadEvent:
332+
value: BlobPayloadEvent
333+
334+
def serialize(self, serializer: ShapeSerializer):
335+
serializer.write_struct(SCHEMA_EVENT_STREAM, self)
336+
337+
def serialize_members(self, serializer: ShapeSerializer):
338+
serializer.write_struct(
339+
SCHEMA_EVENT_STREAM.members["blobPayload"], self.value
340+
)
341+
342+
276343
@dataclass
277344
class ErrorEvent:
278345
code: ClassVar[str] = "NoSuchResource"
@@ -326,7 +393,7 @@ def serialize_members(self, serializer: ShapeSerializer):
326393
raise SmithyException("Unknown union variants may not be serialized.")
327394

328395

329-
type EventStream = EventStreamMessageEvent | EventStreamPayloadEvent | EventStreamErrorEvent | EventStreamUnknownEvent
396+
type EventStream = EventStreamMessageEvent | EventStreamPayloadEvent | EventStreamBlobPayloadEvent | EventStreamErrorEvent | EventStreamUnknownEvent
330397

331398

332399
class EventStreamDeserializer:
@@ -350,6 +417,11 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None:
350417
self._set_result(EventStreamPayloadEvent(PayloadEvent.deserialize(de)))
351418

352419
case 2:
420+
self._set_result(
421+
EventStreamBlobPayloadEvent(BlobPayloadEvent.deserialize(de))
422+
)
423+
424+
case 3:
353425
self._set_result(EventStreamErrorEvent(ErrorEvent.deserialize(de)))
354426

355427
case _:
@@ -528,7 +600,21 @@ def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
528600
"header": "header",
529601
":content-type": "text/plain",
530602
},
531-
payload=b'"payload"',
603+
payload=b"payload",
604+
),
605+
),
606+
(
607+
EventStreamBlobPayloadEvent(
608+
BlobPayloadEvent(header="header", payload=b"\x07beep\x07")
609+
),
610+
EventMessage(
611+
headers={
612+
":message-type": "event",
613+
":event-type": "blobPayload",
614+
"header": "header",
615+
":content-type": "application/octet-stream",
616+
},
617+
payload=b"\x07beep\x07",
532618
),
533619
),
534620
(

0 commit comments

Comments
 (0)