Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from smithy_core.shapes import ShapeType
from smithy_event_stream.aio.interfaces import AsyncEventPublisher

from ..events import EventHeaderEncoder, EventMessage
from ..events import EventMessage, HEADER_VALUE, Short, Byte, Long
from ..exceptions import InvalidHeaderValue
from . import (
INITIAL_REQUEST_EVENT_TYPE,
Expand Down Expand Up @@ -100,30 +100,27 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]:
finally:
return

headers_encoder = EventHeaderEncoder()
headers: dict[str, HEADER_VALUE] = {}

if ErrorTrait in schema:
headers_encoder.encode_string(":message-type", "exception")
headers_encoder.encode_string(
":exception-type", schema.expect_member_name()
)
headers[":message-type"] = "exception"
headers[":exception-type"] = schema.expect_member_name()
else:
headers_encoder.encode_string(":message-type", "event")
headers[":message-type"] = "event"
if schema.member_name is None:
# If there's no member name, that must mean that the structure is
# either an input or output structure, and so this represents the
# initial message.
headers_encoder.encode_string(
":event-type", self._initial_message_event_type
)
headers[":event-type"] = self._initial_message_event_type
else:
headers_encoder.encode_string(":event-type", schema.member_name)
headers[":event-type"] = schema.member_name

payload = BytesIO()
payload_serializer: ShapeSerializer = self._payload_codec.create_serializer(
payload
)
header_serializer = EventHeaderSerializer(headers_encoder)

header_serializer = EventHeaderSerializer(headers)

media_type = self._payload_codec.media_type

Expand All @@ -138,11 +135,9 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]:

payload_bytes = payload.getvalue()
if payload_bytes:
headers_encoder.encode_string(":content-type", media_type)
headers[":content-type"] = media_type

self._result = EventMessage(
headers_bytes=headers_encoder.get_result(), payload=payload_bytes
)
self._result = EventMessage(headers=headers, payload=payload_bytes)

def _get_payload_media_type(self, schema: Schema, default: str) -> str:
if (media_type := schema.get_trait(MediaTypeTrait)) is not None:
Expand All @@ -158,8 +153,8 @@ def _get_payload_media_type(self, schema: Schema, default: str) -> str:


class EventHeaderSerializer(SpecificShapeSerializer):
def __init__(self, encoder: EventHeaderEncoder) -> None:
self._encoder = encoder
def __init__(self, headers: dict[str, HEADER_VALUE]) -> None:
self._headers = headers

def _invalid_state(
self, schema: "Schema | None" = None, message: str | None = None
Expand All @@ -169,28 +164,28 @@ def _invalid_state(
raise InvalidHeaderValue(message)

def write_boolean(self, schema: "Schema", value: bool) -> None:
self._encoder.encode_boolean(schema.expect_member_name(), value)
self._headers[schema.expect_member_name()] = value

def write_byte(self, schema: "Schema", value: int) -> None:
self._encoder.encode_byte(schema.expect_member_name(), value)
self._headers[schema.expect_member_name()] = Byte(value)

def write_short(self, schema: "Schema", value: int) -> None:
self._encoder.encode_short(schema.expect_member_name(), value)
self._headers[schema.expect_member_name()] = Short(value)

def write_integer(self, schema: "Schema", value: int) -> None:
self._encoder.encode_integer(schema.expect_member_name(), value)
self._headers[schema.expect_member_name()] = value

def write_long(self, schema: "Schema", value: int) -> None:
self._encoder.encode_long(schema.expect_member_name(), value)
self._headers[schema.expect_member_name()] = Long(value)

def write_string(self, schema: "Schema", value: str) -> None:
self._encoder.encode_string(schema.expect_member_name(), value)
self._headers[schema.expect_member_name()] = value

def write_blob(self, schema: "Schema", value: bytes) -> None:
self._encoder.encode_blob(schema.expect_member_name(), value)
self._headers[schema.expect_member_name()] = value

def write_timestamp(self, schema: "Schema", value: datetime.datetime) -> None:
self._encoder.encode_timestamp(schema.expect_member_name(), value)
self._headers[schema.expect_member_name()] = value


class RawPayloadSerializer(SpecificShapeSerializer):
Expand Down
94 changes: 26 additions & 68 deletions packages/aws-event-stream/src/aws_event_stream/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import uuid
from binascii import crc32
from collections.abc import Callable, Iterator, Mapping
from dataclasses import dataclass
from dataclasses import dataclass, field
from io import BytesIO
from struct import pack, unpack
from types import MappingProxyType
Expand Down Expand Up @@ -147,6 +147,7 @@ def __post_init__(self):
raise InvalidPayloadLength(payload_length)


@dataclass(kw_only=True, eq=False)
class EventMessage:
"""A message that may be sent over an event stream.

Expand Down Expand Up @@ -186,76 +187,31 @@ class EventMessage:
message.
"""

def __init__(
self,
*,
headers: HEADERS_DICT | None = None,
headers_bytes: bytes | None = None,
payload: bytes = b"",
) -> None:
"""Initialize an EventMessage.

:param headers: The headers present in the event message. If this parameter is
unspecified, the default value will be the decoded value of the
`headers_bytes` parameter.

Sized integer values may be indicated for the purpose of serialization
using the `Byte`, `Short`, or `Long` types. int values of unspecified size
will be assumed to be 32-bit.

:param headers_bytes: The serialized bytes of the headers present in the event
message.

:param payload: The serialized bytes of the message payload.
"""
self._payload = payload
self._headers_bytes = headers_bytes

if len(payload) > MAX_PAYLOAD_LENGTH:
raise InvalidPayloadLength(len(payload))

if headers_bytes is None:
if headers is None:
headers = {}
elif headers is None:
headers = EventHeaderDecoder(headers_bytes).decode_headers()
headers: HEADERS_DICT = field(default_factory=dict)
"""The headers present in the event message.

self._headers = headers

@property
def payload(self) -> bytes:
"""The serialized bytes of the message payload.

These bytes may be in any format or media type. The `:content-type` header, if
present, indicates the media type.
"""
return self._payload
Sized integer values may be indicated for the purpose of serialization
using the `Byte`, `Short`, or `Long` types. int values of unspecified size
will be assumed to be 32-bit.
"""

@property
def headers(self) -> HEADERS_DICT:
"""The headers of the event message.
payload: bytes = b""
"""The serialized bytes of the message payload."""

Headers prefixed with `:` contain metadata by convention.
"""
return self._headers
def __post_init__(
self,
) -> None:
if len(self.payload) > MAX_PAYLOAD_LENGTH:
raise InvalidPayloadLength(len(self.payload))

def _get_headers_bytes(self) -> bytes:
if self._headers_bytes is None:
encoder = EventHeaderEncoder()
encoder.encode_headers(self._headers)
self._headers_bytes = encoder.get_result()

return self._headers_bytes
encoder = EventHeaderEncoder()
encoder.encode_headers(self.headers)
return encoder.get_result()

def encode(self) -> bytes:
return _EventEncoder().encode_bytes(
headers=self._get_headers_bytes(), payload=self._payload
)

def __repr__(self) -> str:
return (
f"EventMessage(payload={self._payload!r}, headers={self.headers!r}, "
f"headers_bytes={self._get_headers_bytes()!r})"
headers=self._get_headers_bytes(), payload=self.payload
)

def __eq__(self, other: object) -> bool:
Expand Down Expand Up @@ -325,8 +281,9 @@ def decode(cls, source: BytesReader) -> Self | None:

_validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc)

headers_bytes = message_bytes[: prelude.headers_length]
message = EventMessage(
headers_bytes=message_bytes[: prelude.headers_length],
headers=EventHeaderDecoder(headers_bytes).decode_headers(),
payload=message_bytes[prelude.headers_length :],
)
return cls(prelude, message, crc)
Expand Down Expand Up @@ -369,8 +326,9 @@ async def decode_async(cls, source: AsyncByteStream) -> Self | None:

_validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc)

headers_bytes = message_bytes[: prelude.headers_length]
message = EventMessage(
headers_bytes=message_bytes[: prelude.headers_length],
headers=EventHeaderDecoder(headers_bytes).decode_headers(),
payload=message_bytes[prelude.headers_length :],
)
return cls(prelude, message, crc)
Expand Down Expand Up @@ -647,7 +605,7 @@ def unpack_int8(data: BytesLike):
:returns: A tuple containing the (parsed integer value, bytes consumed)
"""
value = unpack(_DecodeUtils.INT8_BYTE_FORMAT, data[:1])[0]
return value, 1
return Byte(value), 1

@staticmethod
def unpack_int16(data: BytesLike) -> tuple[int, int]:
Expand All @@ -657,7 +615,7 @@ def unpack_int16(data: BytesLike) -> tuple[int, int]:
:returns: A tuple containing the (parsed integer value, bytes consumed)
"""
value = unpack(_DecodeUtils.INT16_BYTE_FORMAT, data[:2])[0]
return value, 2
return Short(value), 2

@staticmethod
def unpack_int32(data: BytesLike) -> tuple[int, int]:
Expand All @@ -677,7 +635,7 @@ def unpack_int64(data: BytesLike) -> tuple[int, int]:
:returns: A tuple containing the (parsed integer value, bytes consumed)
"""
value = unpack(_DecodeUtils.INT64_BYTE_FORMAT, data[:8])[0]
return value, 8
return Long(value), 8

@staticmethod
def unpack_byte_array(
Expand Down
27 changes: 0 additions & 27 deletions packages/aws-event-stream/tests/unit/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,33 +580,6 @@ def test_event_message_rejects_long_header_value():
EventMessage(headers=headers).encode()


def test_event_message_rejects_long_headers():
# 5 of these is more than enough to overcome the header size limit.
long_value = b"0" * (MAX_HEADER_VALUE_BYTE_LENGTH - 1)
headers = {
"1": long_value,
"2": long_value,
"3": long_value,
"4": long_value,
"5": long_value,
}
with pytest.raises(InvalidHeadersLength):
EventMessage(headers=headers).encode()

# These are correctly encoded, and individually valid, but collectively too long.
long_headers = b""
for i in range(5):
long_headers += b"\x01" + str(i).encode("utf-8") + b"\x06\x7f\xfe" + long_value

with pytest.raises(InvalidHeadersLength):
EventMessage(headers_bytes=long_headers)


def test_event_message_decodes_headers():
message = EventMessage(headers_bytes=b"\x04true\x00")
assert message.headers == {"true": True}


def test_event_encoder_rejects_long_headers():
long_value = b"0" * (MAX_HEADER_VALUE_BYTE_LENGTH - 1)
long_headers = b""
Expand Down