Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from smithy_core.schemas import Schema

from .traits import EVENT_PAYLOAD_TRAIT

INITIAL_REQUEST_EVENT_TYPE = "initial-request"
INITIAL_RESPONSE_EVENT_TYPE = "initial-response"


def get_payload_member(schema: Schema) -> Schema | None:
for member in schema.members.values():
if EVENT_PAYLOAD_TRAIT in member.traits:
return member
return None
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,18 @@
SpecificShapeDeserializer,
)
from smithy_core.schemas import Schema
from smithy_core.shapes import ShapeType
from smithy_core.utils import expect_type
from smithy_event_stream.aio.interfaces import AsyncEventReceiver

from ..events import HEADERS_DICT, Event
from ..exceptions import EventError, UnmodeledEventError
from . import INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE
from .traits import EVENT_HEADER_TRAIT, EVENT_PAYLOAD_TRAIT
from . import (
INITIAL_REQUEST_EVENT_TYPE,
INITIAL_RESPONSE_EVENT_TYPE,
get_payload_member,
)
from .traits import EVENT_HEADER_TRAIT

INITIAL_MESSAGE_TYPES = (INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE)

Expand Down Expand Up @@ -69,25 +74,24 @@ def read_struct(
) -> None:
headers = self._event.message.headers

payload_deserializer = None
if self._event.message.payload:
payload_deserializer = self._payload_codec.create_deserializer(
self._event.message.payload
)

message_deserializer = EventMessageDeserializer(headers, payload_deserializer)

match headers.get(":message-type"):
case "event":
member_name = expect_type(str, headers[":event-type"])
if member_name in INITIAL_MESSAGE_TYPES:
# If it's an initial message, skip straight to deserialization.
message_deserializer = self._create_deserializer(schema, headers)
message_deserializer.read_struct(schema, consumer)
else:
consumer(schema.members[member_name], message_deserializer)
member_schema = schema.members[member_name]
message_deserializer = self._create_deserializer(
member_schema, headers
)
consumer(member_schema, message_deserializer)
case "exception":
member_name = expect_type(str, headers[":exception-type"])
consumer(schema.members[member_name], message_deserializer)
member_schema = schema.members[member_name]
message_deserializer = self._create_deserializer(member_schema, headers)
consumer(member_schema, message_deserializer)
case "error":
# The `application/vnd.amazon.eventstream` format allows for explicitly
# unmodeled exceptions. These exceptions MUST have the `:error-code`
Expand All @@ -99,13 +103,49 @@ def read_struct(
case _:
raise EventError(f"Unknown event structure: {self._event}")

def _create_deserializer(
self, schema: Schema, headers: HEADERS_DICT
) -> ShapeDeserializer:
payload_member = get_payload_member(schema)
payload_deserializer = self._create_payload_deserializer(payload_member)
return EventMessageDeserializer(headers, payload_deserializer, payload_member)

def _create_payload_deserializer(
self, payload_member: Schema | None
) -> ShapeDeserializer | None:
if not self._event.message.payload:
return

if payload_member is not None and payload_member.shape_type in (
ShapeType.BLOB,
ShapeType.STRING,
):
return RawPayloadDeserializer(self._event.message.payload)

return self._payload_codec.create_deserializer(self._event.message.payload)


class RawPayloadDeserializer(SpecificShapeDeserializer):
def __init__(self, payload: bytes) -> None:
self._payload = payload

def read_string(self, schema: Schema) -> str:
return self._payload.decode("utf-8")

def read_blob(self, schema: Schema) -> bytes:
return self._payload


class EventMessageDeserializer(SpecificShapeDeserializer):
def __init__(
self, headers: HEADERS_DICT, payload_deserializer: ShapeDeserializer | None
self,
headers: HEADERS_DICT,
payload_deserializer: ShapeDeserializer | None,
payload_member: Schema | None,
) -> None:
self._headers = headers
self._payload_deserializer = payload_deserializer
self._payload_member = payload_member

def read_struct(
self,
Expand All @@ -119,17 +159,11 @@ def read_struct(
consumer(member_schema, headers_deserializer)

if self._payload_deserializer:
if (payload_member := self._get_payload_member(schema)) is not None:
consumer(payload_member, self._payload_deserializer)
if self._payload_member is not None:
consumer(self._payload_member, self._payload_deserializer)
else:
self._payload_deserializer.read_struct(schema, consumer)

def _get_payload_member(self, schema: "Schema") -> "Schema | None":
for member in schema.members.values():
if EVENT_PAYLOAD_TRAIT in member.traits:
return member
return None


class EventHeaderDeserializer(SpecificShapeDeserializer):
def __init__(self, headers: HEADERS_DICT) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@

from ..events import EventHeaderEncoder, EventMessage
from ..exceptions import InvalidHeaderValue
from . import INITIAL_REQUEST_EVENT_TYPE, INITIAL_RESPONSE_EVENT_TYPE
from .traits import (
ERROR_TRAIT,
EVENT_HEADER_TRAIT,
EVENT_PAYLOAD_TRAIT,
MEDIA_TYPE_TRAIT,
from . import (
INITIAL_REQUEST_EVENT_TYPE,
INITIAL_RESPONSE_EVENT_TYPE,
get_payload_member,
)
from .traits import ERROR_TRAIT, EVENT_HEADER_TRAIT, MEDIA_TYPE_TRAIT

_DEFAULT_STRING_CONTENT_TYPE = "text/plain"
_DEFAULT_BLOB_CONTENT_TYPE = "application/octet-stream"
Expand Down Expand Up @@ -129,8 +128,10 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]:

media_type = self._payload_codec.media_type

if (payload_member := self._get_payload_member(schema)) is not None:
if (payload_member := get_payload_member(schema)) is not None:
media_type = self._get_payload_media_type(payload_member, media_type)
if payload_member.shape_type in (ShapeType.BLOB, ShapeType.STRING):
payload_serializer = RawPayloadSerializer(payload)
yield EventStreamBindingSerializer(header_serializer, payload_serializer)
else:
with payload_serializer.begin_struct(schema) as body_serializer:
Expand All @@ -144,12 +145,6 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]:
headers_bytes=headers_encoder.get_result(), payload=payload_bytes
)

def _get_payload_member(self, schema: Schema) -> Schema | None:
for member in schema.members.values():
if EVENT_PAYLOAD_TRAIT in member.traits:
return member
return None

def _get_payload_media_type(self, schema: Schema, default: str) -> str:
if (media_type := schema.traits.get(MEDIA_TYPE_TRAIT)) is not None:
return expect_type(str, media_type.value)
Expand Down Expand Up @@ -200,19 +195,30 @@ def write_timestamp(self, schema: "Schema", value: datetime.datetime) -> None:
self._encoder.encode_timestamp(schema.expect_member_name(), value)


class RawPayloadSerializer(SpecificShapeSerializer):
def __init__(self, payload: BytesIO) -> None:
self._payload = payload

def write_string(self, schema: "Schema", value: str) -> None:
self._payload.write(value.encode("utf-8"))

def write_blob(self, schema: "Schema", value: bytes) -> None:
self._payload.write(value)


class EventStreamBindingSerializer(InterceptingSerializer):
def __init__(
self,
header_serializer: EventHeaderSerializer,
payload_serializer: ShapeSerializer,
payload_struct_serializer: ShapeSerializer,
) -> None:
self._header_serializer = header_serializer
self._payload_serializer = payload_serializer
self._payload_struct_serializer = payload_struct_serializer

def before(self, schema: "Schema") -> ShapeSerializer:
if EVENT_HEADER_TRAIT in schema.traits:
return self._header_serializer
return self._payload_serializer
return self._payload_struct_serializer

def after(self, schema: "Schema") -> None:
pass
112 changes: 99 additions & 13 deletions python-packages/aws-event-stream/tests/unit/_private/__init__.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,27 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import datetime
from typing import Any, Self, ClassVar, Literal

from aws_event_stream.events import EventMessage, Short, Byte, Long
from dataclasses import dataclass
from typing import Any, ClassVar, Literal, Self

from smithy_core.serializers import ShapeSerializer
from smithy_core.deserializers import ShapeDeserializer
from smithy_core.schemas import Schema
from smithy_core.exceptions import SmithyException
from smithy_core.shapes import ShapeID, ShapeType
from smithy_core.traits import Trait
from smithy_core.prelude import (
BLOB,
BOOLEAN,
BYTE,
SHORT,
INTEGER,
LONG,
BLOB,
SHORT,
STRING,
TIMESTAMP,
)
from smithy_core.schemas import Schema
from smithy_core.serializers import ShapeSerializer
from smithy_core.shapes import ShapeID, ShapeType
from smithy_core.traits import Trait

from aws_event_stream.events import Byte, EventMessage, Long, Short

EVENT_HEADER_TRAIT = Trait(id=ShapeID("smithy.api#eventHeader"))
EVENT_PAYLOAD_TRAIT = Trait(id=ShapeID("smithy.api#eventPayload"))
Expand Down Expand Up @@ -66,6 +65,22 @@
},
)

SCHEMA_BLOB_PAYLOAD_EVENT = Schema.collection(
id=ShapeID("smithy.example#BlobPayloadEvent"),
members={
"header": {
"index": 0,
"target": STRING,
"traits": [EVENT_HEADER_TRAIT, REQUIRED_TRAIT],
},
"payload": {
"index": 1,
"target": BLOB,
"traits": [EVENT_PAYLOAD_TRAIT, REQUIRED_TRAIT],
},
},
)

SCHEMA_ERROR_EVENT = Schema.collection(
id=ShapeID("smithy.example#ErrorEvent"),
members={"message": {"index": 0, "target": STRING, "traits": [REQUIRED_TRAIT]}},
Expand All @@ -79,7 +94,8 @@
members={
"message": {"index": 0, "target": SCHEMA_MESSAGE_EVENT},
"payload": {"index": 1, "target": SCHEMA_PAYLOAD_EVENT},
"error": {"index": 2, "target": SCHEMA_ERROR_EVENT},
"blobPayload": {"index": 2, "target": SCHEMA_BLOB_PAYLOAD_EVENT},
"error": {"index": 3, "target": SCHEMA_ERROR_EVENT},
},
)

Expand Down Expand Up @@ -273,6 +289,57 @@ def serialize_members(self, serializer: ShapeSerializer):
serializer.write_struct(SCHEMA_EVENT_STREAM.members["payload"], self.value)


@dataclass
class BlobPayloadEvent:
header: str
payload: bytes

def serialize(self, serializer: ShapeSerializer):
with serializer.begin_struct(SCHEMA_BLOB_PAYLOAD_EVENT) as s:
self.serialize_members(s)

def serialize_members(self, serializer: ShapeSerializer) -> None:
serializer.write_string(
SCHEMA_BLOB_PAYLOAD_EVENT.members["header"], self.header
)
serializer.write_blob(
SCHEMA_BLOB_PAYLOAD_EVENT.members["payload"], self.payload
)

@classmethod
def deserialize(cls, deserializer: ShapeDeserializer) -> Self:
kwargs: dict[str, Any] = {}

def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
match schema.expect_member_index():
case 0:
kwargs["header"] = de.read_string(
SCHEMA_BLOB_PAYLOAD_EVENT.members["header"]
)
case 1:
kwargs["payload"] = de.read_blob(
SCHEMA_BLOB_PAYLOAD_EVENT.members["payload"]
)
case _:
raise SmithyException(f"Unexpected member schema: {schema}")

deserializer.read_struct(schema=SCHEMA_BLOB_PAYLOAD_EVENT, consumer=_consumer)
return cls(**kwargs)


@dataclass
class EventStreamBlobPayloadEvent:
value: BlobPayloadEvent

def serialize(self, serializer: ShapeSerializer):
serializer.write_struct(SCHEMA_EVENT_STREAM, self)

def serialize_members(self, serializer: ShapeSerializer):
serializer.write_struct(
SCHEMA_EVENT_STREAM.members["blobPayload"], self.value
)


@dataclass
class ErrorEvent:
code: ClassVar[str] = "NoSuchResource"
Expand Down Expand Up @@ -326,7 +393,7 @@ def serialize_members(self, serializer: ShapeSerializer):
raise SmithyException("Unknown union variants may not be serialized.")


type EventStream = EventStreamMessageEvent | EventStreamPayloadEvent | EventStreamErrorEvent | EventStreamUnknownEvent
type EventStream = EventStreamMessageEvent | EventStreamPayloadEvent | EventStreamBlobPayloadEvent | EventStreamErrorEvent | EventStreamUnknownEvent


class EventStreamDeserializer:
Expand All @@ -350,6 +417,11 @@ def _consumer(self, schema: Schema, de: ShapeDeserializer) -> None:
self._set_result(EventStreamPayloadEvent(PayloadEvent.deserialize(de)))

case 2:
self._set_result(
EventStreamBlobPayloadEvent(BlobPayloadEvent.deserialize(de))
)

case 3:
self._set_result(EventStreamErrorEvent(ErrorEvent.deserialize(de)))

case _:
Expand Down Expand Up @@ -528,7 +600,21 @@ def _consumer(schema: Schema, de: ShapeDeserializer) -> None:
"header": "header",
":content-type": "text/plain",
},
payload=b'"payload"',
payload=b"payload",
),
),
(
EventStreamBlobPayloadEvent(
BlobPayloadEvent(header="header", payload=b"\x07beep\x07")
),
EventMessage(
headers={
":message-type": "event",
":event-type": "blobPayload",
"header": "header",
":content-type": "application/octet-stream",
},
payload=b"\x07beep\x07",
),
),
(
Expand Down
Loading