diff --git a/python-packages/aws-event-stream/aws_event_stream/_private/__init__.py b/python-packages/aws-event-stream/aws_event_stream/_private/__init__.py index 75f7768ab..3b4609496 100644 --- a/python-packages/aws-event-stream/aws_event_stream/_private/__init__.py +++ b/python-packages/aws-event-stream/aws_event_stream/_private/__init__.py @@ -1,5 +1,15 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from smithy_core.schemas import Schema + +from .traits import EVENT_PAYLOAD_TRAIT INITIAL_REQUEST_EVENT_TYPE = "initial-request" INITIAL_RESPONSE_EVENT_TYPE = "initial-response" + + +def get_payload_member(schema: Schema) -> Schema | None: + for member in schema.members.values(): + if EVENT_PAYLOAD_TRAIT in member.traits: + return member + return None diff --git a/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py b/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py index b3af08b83..5bfff8ed4 100644 --- a/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py +++ b/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py @@ -12,13 +12,18 @@ SpecificShapeDeserializer, ) from smithy_core.schemas import Schema +from smithy_core.shapes import ShapeType from smithy_core.utils import expect_type from smithy_event_stream.aio.interfaces import AsyncEventReceiver from ..events import HEADERS_DICT, Event from ..exceptions import EventError, UnmodeledEventError -from . import INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE -from .traits import EVENT_HEADER_TRAIT, EVENT_PAYLOAD_TRAIT +from . import ( + INITIAL_REQUEST_EVENT_TYPE, + INITIAL_RESPONSE_EVENT_TYPE, + get_payload_member, +) +from .traits import EVENT_HEADER_TRAIT INITIAL_MESSAGE_TYPES = (INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE) @@ -69,25 +74,24 @@ def read_struct( ) -> None: headers = self._event.message.headers - payload_deserializer = None - if self._event.message.payload: - payload_deserializer = self._payload_codec.create_deserializer( - self._event.message.payload - ) - - message_deserializer = EventMessageDeserializer(headers, payload_deserializer) - match headers.get(":message-type"): case "event": member_name = expect_type(str, headers[":event-type"]) if member_name in INITIAL_MESSAGE_TYPES: # If it's an initial message, skip straight to deserialization. + message_deserializer = self._create_deserializer(schema, headers) message_deserializer.read_struct(schema, consumer) else: - consumer(schema.members[member_name], message_deserializer) + member_schema = schema.members[member_name] + message_deserializer = self._create_deserializer( + member_schema, headers + ) + consumer(member_schema, message_deserializer) case "exception": member_name = expect_type(str, headers[":exception-type"]) - consumer(schema.members[member_name], message_deserializer) + member_schema = schema.members[member_name] + message_deserializer = self._create_deserializer(member_schema, headers) + consumer(member_schema, message_deserializer) case "error": # The `application/vnd.amazon.eventstream` format allows for explicitly # unmodeled exceptions. These exceptions MUST have the `:error-code` @@ -99,13 +103,49 @@ def read_struct( case _: raise EventError(f"Unknown event structure: {self._event}") + def _create_deserializer( + self, schema: Schema, headers: HEADERS_DICT + ) -> ShapeDeserializer: + payload_member = get_payload_member(schema) + payload_deserializer = self._create_payload_deserializer(payload_member) + return EventMessageDeserializer(headers, payload_deserializer, payload_member) + + def _create_payload_deserializer( + self, payload_member: Schema | None + ) -> ShapeDeserializer | None: + if not self._event.message.payload: + return + + if payload_member is not None and payload_member.shape_type in ( + ShapeType.BLOB, + ShapeType.STRING, + ): + return RawPayloadDeserializer(self._event.message.payload) + + return self._payload_codec.create_deserializer(self._event.message.payload) + + +class RawPayloadDeserializer(SpecificShapeDeserializer): + def __init__(self, payload: bytes) -> None: + self._payload = payload + + def read_string(self, schema: Schema) -> str: + return self._payload.decode("utf-8") + + def read_blob(self, schema: Schema) -> bytes: + return self._payload + class EventMessageDeserializer(SpecificShapeDeserializer): def __init__( - self, headers: HEADERS_DICT, payload_deserializer: ShapeDeserializer | None + self, + headers: HEADERS_DICT, + payload_deserializer: ShapeDeserializer | None, + payload_member: Schema | None, ) -> None: self._headers = headers self._payload_deserializer = payload_deserializer + self._payload_member = payload_member def read_struct( self, @@ -119,17 +159,11 @@ def read_struct( consumer(member_schema, headers_deserializer) if self._payload_deserializer: - if (payload_member := self._get_payload_member(schema)) is not None: - consumer(payload_member, self._payload_deserializer) + if self._payload_member is not None: + consumer(self._payload_member, self._payload_deserializer) else: self._payload_deserializer.read_struct(schema, consumer) - def _get_payload_member(self, schema: "Schema") -> "Schema | None": - for member in schema.members.values(): - if EVENT_PAYLOAD_TRAIT in member.traits: - return member - return None - class EventHeaderDeserializer(SpecificShapeDeserializer): def __init__(self, headers: HEADERS_DICT) -> None: diff --git a/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py b/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py index 2edbbb21e..ae8ddac56 100644 --- a/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py +++ b/python-packages/aws-event-stream/aws_event_stream/_private/serializers.py @@ -23,13 +23,12 @@ from ..events import EventHeaderEncoder, EventMessage from ..exceptions import InvalidHeaderValue -from . import INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE -from .traits import ( - ERROR_TRAIT, - EVENT_HEADER_TRAIT, - EVENT_PAYLOAD_TRAIT, - MEDIA_TYPE_TRAIT, +from . import ( + INITIAL_REQUEST_EVENT_TYPE, + INITIAL_RESPONSE_EVENT_TYPE, + get_payload_member, ) +from .traits import ERROR_TRAIT, EVENT_HEADER_TRAIT, MEDIA_TYPE_TRAIT _DEFAULT_STRING_CONTENT_TYPE = "text/plain" _DEFAULT_BLOB_CONTENT_TYPE = "application/octet-stream" @@ -129,8 +128,10 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]: media_type = self._payload_codec.media_type - if (payload_member := self._get_payload_member(schema)) is not None: + if (payload_member := get_payload_member(schema)) is not None: media_type = self._get_payload_media_type(payload_member, media_type) + if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING): + payload_serializer = RawPayloadSerializer(payload) yield EventStreamBindingSerializer(header_serializer, payload_serializer) else: with payload_serializer.begin_struct(schema) as body_serializer: @@ -144,12 +145,6 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]: headers_bytes=headers_encoder.get_result(), payload=payload_bytes ) - def _get_payload_member(self, schema: Schema) -> Schema | None: - for member in schema.members.values(): - if EVENT_PAYLOAD_TRAIT in member.traits: - return member - return None - def _get_payload_media_type(self, schema: Schema, default: str) -> str: if (media_type := schema.traits.get(MEDIA_TYPE_TRAIT)) is not None: return expect_type(str, media_type.value) @@ -200,19 +195,30 @@ def write_timestamp(self, schema: "Schema", value: datetime.datetime) -> None: self._encoder.encode_timestamp(schema.expect_member_name(), value) +class RawPayloadSerializer(SpecificShapeSerializer): + def __init__(self, payload: BytesIO) -> None: + self._payload = payload + + def write_string(self, schema: "Schema", value: str) -> None: + self._payload.write(value.encode("utf-8")) + + def write_blob(self, schema: "Schema", value: bytes) -> None: + self._payload.write(value) + + class EventStreamBindingSerializer(InterceptingSerializer): def __init__( self, header_serializer: EventHeaderSerializer, - payload_serializer: ShapeSerializer, + payload_struct_serializer: ShapeSerializer, ) -> None: self._header_serializer = header_serializer - self._payload_serializer = payload_serializer + self._payload_struct_serializer = payload_struct_serializer def before(self, schema: "Schema") -> ShapeSerializer: if EVENT_HEADER_TRAIT in schema.traits: return self._header_serializer - return self._payload_serializer + return self._payload_struct_serializer def after(self, schema: "Schema") -> None: pass diff --git a/python-packages/aws-event-stream/tests/unit/_private/__init__.py b/python-packages/aws-event-stream/tests/unit/_private/__init__.py index ff7858e16..54391d038 100644 --- a/python-packages/aws-event-stream/tests/unit/_private/__init__.py +++ b/python-packages/aws-event-stream/tests/unit/_private/__init__.py @@ -1,28 +1,27 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from dataclasses import dataclass import datetime -from typing import Any, Self, ClassVar, Literal - -from aws_event_stream.events import EventMessage, Short, Byte, Long +from dataclasses import dataclass +from typing import Any, ClassVar, Literal, Self -from smithy_core.serializers import ShapeSerializer from smithy_core.deserializers import ShapeDeserializer -from smithy_core.schemas import Schema from smithy_core.exceptions import SmithyException -from smithy_core.shapes import ShapeID, ShapeType -from smithy_core.traits import Trait from smithy_core.prelude import ( + BLOB, BOOLEAN, BYTE, - SHORT, INTEGER, LONG, - BLOB, + SHORT, STRING, TIMESTAMP, ) +from smithy_core.schemas import Schema +from smithy_core.serializers import ShapeSerializer +from smithy_core.shapes import ShapeID, ShapeType +from smithy_core.traits import Trait +from aws_event_stream.events import Byte, EventMessage, Long, Short EVENT_HEADER_TRAIT = Trait(id=ShapeID("smithy.api#eventHeader")) EVENT_PAYLOAD_TRAIT = Trait(id=ShapeID("smithy.api#eventPayload")) @@ -66,6 +65,22 @@ }, ) +SCHEMA_BLOB_PAYLOAD_EVENT = Schema.collection( + id=ShapeID("smithy.example#BlobPayloadEvent"), + members={ + "header": { + "index": 0, + "target": STRING, + "traits": [EVENT_HEADER_TRAIT, REQUIRED_TRAIT], + }, + "payload": { + "index": 1, + "target": BLOB, + "traits": [EVENT_PAYLOAD_TRAIT, REQUIRED_TRAIT], + }, + }, +) + SCHEMA_ERROR_EVENT = Schema.collection( id=ShapeID("smithy.example#ErrorEvent"), members={"message": {"index": 0, "target": STRING, "traits": [REQUIRED_TRAIT]}}, @@ -79,7 +94,8 @@ members={ "message": {"index": 0, "target": SCHEMA_MESSAGE_EVENT}, "payload": {"index": 1, "target": SCHEMA_PAYLOAD_EVENT}, - "error": {"index": 2, "target": SCHEMA_ERROR_EVENT}, + "blobPayload": {"index": 2, "target": SCHEMA_BLOB_PAYLOAD_EVENT}, + "error": {"index": 3, "target": SCHEMA_ERROR_EVENT}, }, ) @@ -273,6 +289,57 @@ def serialize_members(self, serializer: ShapeSerializer): serializer.write_struct(SCHEMA_EVENT_STREAM.members["payload"], self.value) +@dataclass +class BlobPayloadEvent: + header: str + payload: bytes + + def serialize(self, serializer: ShapeSerializer): + with serializer.begin_struct(SCHEMA_BLOB_PAYLOAD_EVENT) as s: + self.serialize_members(s) + + def serialize_members(self, serializer: ShapeSerializer) -> None: + serializer.write_string( + SCHEMA_BLOB_PAYLOAD_EVENT.members["header"], self.header + ) + serializer.write_blob( + SCHEMA_BLOB_PAYLOAD_EVENT.members["payload"], self.payload + ) + + @classmethod + def deserialize(cls, deserializer: ShapeDeserializer) -> Self: + kwargs: dict[str, Any] = {} + + def _consumer(schema: Schema, de: ShapeDeserializer) -> None: + match schema.expect_member_index(): + case 0: + kwargs["header"] = de.read_string( + SCHEMA_BLOB_PAYLOAD_EVENT.members["header"] + ) + case 1: + kwargs["payload"] = de.read_blob( + SCHEMA_BLOB_PAYLOAD_EVENT.members["payload"] + ) + case _: + raise SmithyException(f"Unexpected member schema: {schema}") + + deserializer.read_struct(schema=SCHEMA_BLOB_PAYLOAD_EVENT, consumer=_consumer) + return cls(**kwargs) + + +@dataclass +class EventStreamBlobPayloadEvent: + value: BlobPayloadEvent + + def serialize(self, serializer: ShapeSerializer): + serializer.write_struct(SCHEMA_EVENT_STREAM, self) + + def serialize_members(self, serializer: ShapeSerializer): + serializer.write_struct( + SCHEMA_EVENT_STREAM.members["blobPayload"], self.value + ) + + @dataclass class ErrorEvent: code: ClassVar[str] = "NoSuchResource" @@ -326,7 +393,7 @@ def serialize_members(self, serializer: ShapeSerializer): raise SmithyException("Unknown union variants may not be serialized.") -type EventStream = EventStreamMessageEvent | EventStreamPayloadEvent | EventStreamErrorEvent | EventStreamUnknownEvent +type EventStream = EventStreamMessageEvent | EventStreamPayloadEvent | EventStreamBlobPayloadEvent | EventStreamErrorEvent | EventStreamUnknownEvent class EventStreamDeserializer: @@ -350,6 +417,11 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None: self._set_result(EventStreamPayloadEvent(PayloadEvent.deserialize(de))) case 2: + self._set_result( + EventStreamBlobPayloadEvent(BlobPayloadEvent.deserialize(de)) + ) + + case 3: self._set_result(EventStreamErrorEvent(ErrorEvent.deserialize(de))) case _: @@ -528,7 +600,21 @@ def _consumer(schema: Schema, de: ShapeDeserializer) -> None: "header": "header", ":content-type": "text/plain", }, - payload=b'"payload"', + payload=b"payload", + ), + ), + ( + EventStreamBlobPayloadEvent( + BlobPayloadEvent(header="header", payload=b"\x07beep\x07") + ), + EventMessage( + headers={ + ":message-type": "event", + ":event-type": "blobPayload", + "header": "header", + ":content-type": "application/octet-stream", + }, + payload=b"\x07beep\x07", ), ), (