diff --git a/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py b/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py index 6ea9c1e22..32bb4935f 100644 --- a/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py +++ b/packages/aws-event-stream/src/aws_event_stream/_private/serializers.py @@ -20,7 +20,7 @@ from smithy_core.shapes import ShapeType from smithy_event_stream.aio.interfaces import AsyncEventPublisher -from ..events import EventHeaderEncoder, EventMessage +from ..events import EventMessage, HEADER_VALUE, Short, Byte, Long from ..exceptions import InvalidHeaderValue from . import ( INITIAL_REQUEST_EVENT_TYPE, @@ -100,30 +100,27 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]: finally: return - headers_encoder = EventHeaderEncoder() + headers: dict[str, HEADER_VALUE] = {} if ErrorTrait in schema: - headers_encoder.encode_string(":message-type", "exception") - headers_encoder.encode_string( - ":exception-type", schema.expect_member_name() - ) + headers[":message-type"] = "exception" + headers[":exception-type"] = schema.expect_member_name() else: - headers_encoder.encode_string(":message-type", "event") + headers[":message-type"] = "event" if schema.member_name is None: # If there's no member name, that must mean that the structure is # either an input or output structure, and so this represents the # initial message. - headers_encoder.encode_string( - ":event-type", self._initial_message_event_type - ) + headers[":event-type"] = self._initial_message_event_type else: - headers_encoder.encode_string(":event-type", schema.member_name) + headers[":event-type"] = schema.member_name payload = BytesIO() payload_serializer: ShapeSerializer = self._payload_codec.create_serializer( payload ) - header_serializer = EventHeaderSerializer(headers_encoder) + + header_serializer = EventHeaderSerializer(headers) media_type = self._payload_codec.media_type @@ -138,11 +135,9 @@ def begin_struct(self, schema: "Schema") -> Iterator[ShapeSerializer]: payload_bytes = payload.getvalue() if payload_bytes: - headers_encoder.encode_string(":content-type", media_type) + headers[":content-type"] = media_type - self._result = EventMessage( - headers_bytes=headers_encoder.get_result(), payload=payload_bytes - ) + self._result = EventMessage(headers=headers, payload=payload_bytes) def _get_payload_media_type(self, schema: Schema, default: str) -> str: if (media_type := schema.get_trait(MediaTypeTrait)) is not None: @@ -158,8 +153,8 @@ def _get_payload_media_type(self, schema: Schema, default: str) -> str: class EventHeaderSerializer(SpecificShapeSerializer): - def __init__(self, encoder: EventHeaderEncoder) -> None: - self._encoder = encoder + def __init__(self, headers: dict[str, HEADER_VALUE]) -> None: + self._headers = headers def _invalid_state( self, schema: "Schema | None" = None, message: str | None = None @@ -169,28 +164,28 @@ def _invalid_state( raise InvalidHeaderValue(message) def write_boolean(self, schema: "Schema", value: bool) -> None: - self._encoder.encode_boolean(schema.expect_member_name(), value) + self._headers[schema.expect_member_name()] = value def write_byte(self, schema: "Schema", value: int) -> None: - self._encoder.encode_byte(schema.expect_member_name(), value) + self._headers[schema.expect_member_name()] = Byte(value) def write_short(self, schema: "Schema", value: int) -> None: - self._encoder.encode_short(schema.expect_member_name(), value) + self._headers[schema.expect_member_name()] = Short(value) def write_integer(self, schema: "Schema", value: int) -> None: - self._encoder.encode_integer(schema.expect_member_name(), value) + self._headers[schema.expect_member_name()] = value def write_long(self, schema: "Schema", value: int) -> None: - self._encoder.encode_long(schema.expect_member_name(), value) + self._headers[schema.expect_member_name()] = Long(value) def write_string(self, schema: "Schema", value: str) -> None: - self._encoder.encode_string(schema.expect_member_name(), value) + self._headers[schema.expect_member_name()] = value def write_blob(self, schema: "Schema", value: bytes) -> None: - self._encoder.encode_blob(schema.expect_member_name(), value) + self._headers[schema.expect_member_name()] = value def write_timestamp(self, schema: "Schema", value: datetime.datetime) -> None: - self._encoder.encode_timestamp(schema.expect_member_name(), value) + self._headers[schema.expect_member_name()] = value class RawPayloadSerializer(SpecificShapeSerializer): diff --git a/packages/aws-event-stream/src/aws_event_stream/events.py b/packages/aws-event-stream/src/aws_event_stream/events.py index 8d52bde65..50e8b1598 100644 --- a/packages/aws-event-stream/src/aws_event_stream/events.py +++ b/packages/aws-event-stream/src/aws_event_stream/events.py @@ -12,7 +12,7 @@ import uuid from binascii import crc32 from collections.abc import Callable, Iterator, Mapping -from dataclasses import dataclass +from dataclasses import dataclass, field from io import BytesIO from struct import pack, unpack from types import MappingProxyType @@ -147,6 +147,7 @@ def __post_init__(self): raise InvalidPayloadLength(payload_length) +@dataclass(kw_only=True, eq=False) class EventMessage: """A message that may be sent over an event stream. @@ -186,76 +187,31 @@ class EventMessage: message. """ - def __init__( - self, - *, - headers: HEADERS_DICT | None = None, - headers_bytes: bytes | None = None, - payload: bytes = b"", - ) -> None: - """Initialize an EventMessage. - - :param headers: The headers present in the event message. If this parameter is - unspecified, the default value will be the decoded value of the - `headers_bytes` parameter. - - Sized integer values may be indicated for the purpose of serialization - using the `Byte`, `Short`, or `Long` types. int values of unspecified size - will be assumed to be 32-bit. - - :param headers_bytes: The serialized bytes of the headers present in the event - message. - - :param payload: The serialized bytes of the message payload. - """ - self._payload = payload - self._headers_bytes = headers_bytes - - if len(payload) > MAX_PAYLOAD_LENGTH: - raise InvalidPayloadLength(len(payload)) - - if headers_bytes is None: - if headers is None: - headers = {} - elif headers is None: - headers = EventHeaderDecoder(headers_bytes).decode_headers() + headers: HEADERS_DICT = field(default_factory=dict) + """The headers present in the event message. - self._headers = headers - - @property - def payload(self) -> bytes: - """The serialized bytes of the message payload. - - These bytes may be in any format or media type. The `:content-type` header, if - present, indicates the media type. - """ - return self._payload + Sized integer values may be indicated for the purpose of serialization + using the `Byte`, `Short`, or `Long` types. int values of unspecified size + will be assumed to be 32-bit. + """ - @property - def headers(self) -> HEADERS_DICT: - """The headers of the event message. + payload: bytes = b"" + """The serialized bytes of the message payload.""" - Headers prefixed with `:` contain metadata by convention. - """ - return self._headers + def __post_init__( + self, + ) -> None: + if len(self.payload) > MAX_PAYLOAD_LENGTH: + raise InvalidPayloadLength(len(self.payload)) def _get_headers_bytes(self) -> bytes: - if self._headers_bytes is None: - encoder = EventHeaderEncoder() - encoder.encode_headers(self._headers) - self._headers_bytes = encoder.get_result() - - return self._headers_bytes + encoder = EventHeaderEncoder() + encoder.encode_headers(self.headers) + return encoder.get_result() def encode(self) -> bytes: return _EventEncoder().encode_bytes( - headers=self._get_headers_bytes(), payload=self._payload - ) - - def __repr__(self) -> str: - return ( - f"EventMessage(payload={self._payload!r}, headers={self.headers!r}, " - f"headers_bytes={self._get_headers_bytes()!r})" + headers=self._get_headers_bytes(), payload=self.payload ) def __eq__(self, other: object) -> bool: @@ -325,8 +281,9 @@ def decode(cls, source: BytesReader) -> Self | None: _validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc) + headers_bytes = message_bytes[: prelude.headers_length] message = EventMessage( - headers_bytes=message_bytes[: prelude.headers_length], + headers=EventHeaderDecoder(headers_bytes).decode_headers(), payload=message_bytes[prelude.headers_length :], ) return cls(prelude, message, crc) @@ -369,8 +326,9 @@ async def decode_async(cls, source: AsyncByteStream) -> Self | None: _validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc) + headers_bytes = message_bytes[: prelude.headers_length] message = EventMessage( - headers_bytes=message_bytes[: prelude.headers_length], + headers=EventHeaderDecoder(headers_bytes).decode_headers(), payload=message_bytes[prelude.headers_length :], ) return cls(prelude, message, crc) @@ -647,7 +605,7 @@ def unpack_int8(data: BytesLike): :returns: A tuple containing the (parsed integer value, bytes consumed) """ value = unpack(_DecodeUtils.INT8_BYTE_FORMAT, data[:1])[0] - return value, 1 + return Byte(value), 1 @staticmethod def unpack_int16(data: BytesLike) -> tuple[int, int]: @@ -657,7 +615,7 @@ def unpack_int16(data: BytesLike) -> tuple[int, int]: :returns: A tuple containing the (parsed integer value, bytes consumed) """ value = unpack(_DecodeUtils.INT16_BYTE_FORMAT, data[:2])[0] - return value, 2 + return Short(value), 2 @staticmethod def unpack_int32(data: BytesLike) -> tuple[int, int]: @@ -677,7 +635,7 @@ def unpack_int64(data: BytesLike) -> tuple[int, int]: :returns: A tuple containing the (parsed integer value, bytes consumed) """ value = unpack(_DecodeUtils.INT64_BYTE_FORMAT, data[:8])[0] - return value, 8 + return Long(value), 8 @staticmethod def unpack_byte_array( diff --git a/packages/aws-event-stream/tests/unit/test_events.py b/packages/aws-event-stream/tests/unit/test_events.py index 5223676a7..960dfef76 100644 --- a/packages/aws-event-stream/tests/unit/test_events.py +++ b/packages/aws-event-stream/tests/unit/test_events.py @@ -580,33 +580,6 @@ def test_event_message_rejects_long_header_value(): EventMessage(headers=headers).encode() -def test_event_message_rejects_long_headers(): - # 5 of these is more than enough to overcome the header size limit. - long_value = b"0" * (MAX_HEADER_VALUE_BYTE_LENGTH - 1) - headers = { - "1": long_value, - "2": long_value, - "3": long_value, - "4": long_value, - "5": long_value, - } - with pytest.raises(InvalidHeadersLength): - EventMessage(headers=headers).encode() - - # These are correctly encoded, and individually valid, but collectively too long. - long_headers = b"" - for i in range(5): - long_headers += b"\x01" + str(i).encode("utf-8") + b"\x06\x7f\xfe" + long_value - - with pytest.raises(InvalidHeadersLength): - EventMessage(headers_bytes=long_headers) - - -def test_event_message_decodes_headers(): - message = EventMessage(headers_bytes=b"\x04true\x00") - assert message.headers == {"true": True} - - def test_event_encoder_rejects_long_headers(): long_value = b"0" * (MAX_HEADER_VALUE_BYTE_LENGTH - 1) long_headers = b""