diff --git a/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py b/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py index bd2254119..b3af08b83 100644 --- a/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py +++ b/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py @@ -38,6 +38,9 @@ def __init__( async def receive(self) -> E | None: event = await Event.decode_async(self._source) + if event is None: + return None + deserializer = EventDeserializer( event=event, payload_codec=self._payload_codec, diff --git a/python-packages/aws-event-stream/aws_event_stream/events.py b/python-packages/aws-event-stream/aws_event_stream/events.py index ec77473f0..cdd76d493 100644 --- a/python-packages/aws-event-stream/aws_event_stream/events.py +++ b/python-packages/aws-event-stream/aws_event_stream/events.py @@ -24,6 +24,7 @@ from .exceptions import ( ChecksumMismatch, DuplicateHeader, + InvalidEventBytes, InvalidHeadersLength, InvalidHeaderValue, InvalidHeaderValueLength, @@ -286,27 +287,41 @@ class Event: """ @classmethod - def decode(cls, source: BytesReader) -> Self: + def decode(cls, source: BytesReader) -> Self | None: """Decode an event from a byte stream. :param source: An object to read event bytes from. It must have a `read` method that accepts a number of bytes to read. - :returns: An Event representing the next event on the source. + :returns: An Event representing the next event on the source, or None if no + data can be read from the source. """ prelude_bytes = source.read(8) + if not prelude_bytes: + # If nothing can be read from the source, return None. If bytes are missing + # later, that indicates a problem with the source and therefore will result + # in an exception. + return None + prelude_crc_bytes = source.read(4) - prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0] + try: + prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0] + total_length, headers_length = unpack("!II", prelude_bytes) + except struct.error as e: + raise InvalidEventBytes() from e - total_length, headers_length = unpack("!II", prelude_bytes) _validate_checksum(prelude_bytes, prelude_crc) prelude = EventPrelude( total_length=total_length, headers_length=headers_length, crc=prelude_crc ) message_bytes = source.read(total_length - _MESSAGE_METADATA_SIZE) - crc: int = _DecodeUtils.unpack_uint32(source.read(4))[0] + try: + crc: int = _DecodeUtils.unpack_uint32(source.read(4))[0] + except struct.error as e: + raise InvalidEventBytes() from e + _validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc) message = EventMessage( @@ -316,27 +331,41 @@ def decode(cls, source: BytesReader) -> Self: return cls(prelude, message, crc) @classmethod - async def decode_async(cls, source: AsyncByteStream) -> Self: + async def decode_async(cls, source: AsyncByteStream) -> Self | None: """Decode an event from an async byte stream. :param source: An object to read event bytes from. It must have a `read` method that accepts a number of bytes to read. - :returns: An Event representing the next event on the source. + :returns: An Event representing the next event on the source, or None if no + data can be read from the source. """ prelude_bytes = await source.read(8) + if not prelude_bytes: + # If nothing can be read from the source, return None. If bytes are missing + # later, that indicates a problem with the source and therefore will result + # in an exception. + return None + prelude_crc_bytes = await source.read(4) - prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0] + try: + prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0] + total_length, headers_length = unpack("!II", prelude_bytes) + except struct.error as e: + raise InvalidEventBytes() from e - total_length, headers_length = unpack("!II", prelude_bytes) _validate_checksum(prelude_bytes, prelude_crc) prelude = EventPrelude( total_length=total_length, headers_length=headers_length, crc=prelude_crc ) message_bytes = await source.read(total_length - _MESSAGE_METADATA_SIZE) - crc: int = _DecodeUtils.unpack_uint32(await source.read(4))[0] + try: + crc: int = _DecodeUtils.unpack_uint32(await source.read(4))[0] + except struct.error as e: + raise InvalidEventBytes() from e + _validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc) message = EventMessage( diff --git a/python-packages/aws-event-stream/aws_event_stream/exceptions.py b/python-packages/aws-event-stream/aws_event_stream/exceptions.py index ec705c611..69293c8b2 100644 --- a/python-packages/aws-event-stream/aws_event_stream/exceptions.py +++ b/python-packages/aws-event-stream/aws_event_stream/exceptions.py @@ -93,6 +93,12 @@ def __init__(self, size: str, value: int): super().__init__(message) +class InvalidEventBytes(EventError): + def __init__(self) -> None: + message = "Invalid event bytes." + super().__init__(message) + + class MissingInitialResponse(EventError): def __init__(self) -> None: super().__init__("Expected an initial response, but none was found.") diff --git a/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py b/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py index 301533905..3e4af5def 100644 --- a/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py +++ b/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py @@ -22,6 +22,7 @@ @pytest.mark.parametrize("expected,given", EVENT_STREAM_SERDE_CASES) def test_event_deserializer(expected: DeserializeableShape, given: EventMessage): source = Event.decode(BytesIO(given.encode())) + assert source is not None deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) result = EventStreamDeserializer().deserialize(deserializer) assert result == expected @@ -30,6 +31,7 @@ def test_event_deserializer(expected: DeserializeableShape, given: EventMessage) def test_deserialize_initial_request(): expected, given = INITIAL_REQUEST_CASE source = Event.decode(BytesIO(given.encode())) + assert source is not None deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) result = EventStreamOperationInputOutput.deserialize(deserializer) assert result == expected @@ -38,6 +40,7 @@ def test_deserialize_initial_request(): def test_deserialize_initial_response(): expected, given = INITIAL_RESPONSE_CASE source = Event.decode(BytesIO(given.encode())) + assert source is not None deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) result = EventStreamOperationInputOutput.deserialize(deserializer) assert result == expected @@ -52,6 +55,7 @@ def test_deserialize_unmodeled_error(): } ) source = Event.decode(BytesIO(message.encode())) + assert source is not None deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) with pytest.raises(UnmodeledEventError, match="InternalError"): diff --git a/python-packages/aws-event-stream/tests/unit/test_events.py b/python-packages/aws-event-stream/tests/unit/test_events.py index d6c534bb8..97f97a276 100644 --- a/python-packages/aws-event-stream/tests/unit/test_events.py +++ b/python-packages/aws-event-stream/tests/unit/test_events.py @@ -25,6 +25,7 @@ from aws_event_stream.exceptions import ( ChecksumMismatch, DuplicateHeader, + InvalidEventBytes, InvalidHeadersLength, InvalidHeaderValueLength, InvalidIntegerValue, @@ -381,6 +382,8 @@ ), ) +EMPTY_SOURCE = (b"", None) + # Tuples of encoded messages and their expected decoded output POSITIVE_CASES = [ EMPTY_MESSAGE, # standard @@ -398,6 +401,7 @@ PAYLOAD_ONE_STR_HEADER, # standard ALL_HEADERS_TYPES, # standard ERROR_EVENT_MESSAGE, + EMPTY_SOURCE, ] CORRUPTED_HEADERS_LENGTH = ( @@ -489,47 +493,54 @@ InvalidPayloadLength, ) +TRUNCATED_PRELUDE = (b"\x00", InvalidEventBytes) + +MISSING_PRELUDE_CRC_BYTES = (b"\x00\x00\x00\x16", InvalidEventBytes) + +MISSING_MESSAGE_CRC_BYTES = ( + ( + b"\x00\x00\x00\x10" # total length + b"\x00\x00\x00\x00" # headers length + b"\x05\xc2\x48\xeb" # prelude crc + ), + InvalidEventBytes, +) + # Tuples of encoded messages and their expected exception -NEGATIVE_CASES = [ - CORRUPTED_LENGTH, # standard - CORRUPTED_PAYLOAD, # standard - CORRUPTED_HEADERS, # standard - CORRUPTED_HEADERS_LENGTH, # standard - DUPLICATE_HEADER, - INVALID_HEADERS_LENGTH, - INVALID_HEADER_VALUE_LENGTH, - INVALID_PAYLOAD_LENGTH, -] +NEGATIVE_CASES = { + "corrupted-length": CORRUPTED_LENGTH, # standard + "corrupted-payload": CORRUPTED_PAYLOAD, # standard + "corrupted-headers": CORRUPTED_HEADERS, # standard + "corrupted-headers-length": CORRUPTED_HEADERS_LENGTH, # standard + "duplicate-header": DUPLICATE_HEADER, + "invalid-headers-length": INVALID_HEADERS_LENGTH, + "invalid-header-value-length": INVALID_HEADER_VALUE_LENGTH, + "invalid-payload-length": INVALID_PAYLOAD_LENGTH, + "truncated-prelude": TRUNCATED_PRELUDE, + "missing-prelude-crc-bytes": MISSING_PRELUDE_CRC_BYTES, + "missing-message-crc-bytes": MISSING_MESSAGE_CRC_BYTES, +} @pytest.mark.parametrize("encoded,expected", POSITIVE_CASES) -def test_decode(encoded: bytes, expected: Event): +def test_decode(encoded: bytes, expected: Event | None): assert Event.decode(BytesIO(encoded)) == expected @pytest.mark.parametrize("encoded,expected", POSITIVE_CASES) -async def test_decode_async(encoded: bytes, expected: Event): +async def test_decode_async(encoded: bytes, expected: Event | None): assert await Event.decode_async(AsyncBytesReader(encoded)) == expected -@pytest.mark.parametrize("expected,event", POSITIVE_CASES) +@pytest.mark.parametrize( + "expected,event", [c for c in POSITIVE_CASES if c[1] is not None] +) def test_encode(expected: bytes, event: Event): assert event.message.encode() == expected @pytest.mark.parametrize( - "encoded,expected", - NEGATIVE_CASES, - ids=[ - "corrupted-length", - "corrupted-payload", - "corrupted-headers", - "corrupted-headers-length", - "duplicate-header", - "invalid-headers-length", - "invalid-header-value-length", - "invalid-payload-length", - ], + "encoded,expected", NEGATIVE_CASES.values(), ids=NEGATIVE_CASES.keys() ) def test_negative_cases(encoded: bytes, expected: type[Exception]): with pytest.raises(expected): @@ -537,18 +548,7 @@ def test_negative_cases(encoded: bytes, expected: type[Exception]): @pytest.mark.parametrize( - "encoded,expected", - NEGATIVE_CASES, - ids=[ - "corrupted-length", - "corrupted-payload", - "corrupted-headers", - "corrupted-headers-length", - "duplicate-header", - "invalid-headers-length", - "invalid-header-value-length", - "invalid-payload-length", - ], + "encoded,expected", NEGATIVE_CASES.values(), ids=NEGATIVE_CASES.keys() ) async def test_negative_cases_async(encoded: bytes, expected: type[Exception]): with pytest.raises(expected):