From d700b2457dfe7bedab25f98d5b64305e3b64f19f Mon Sep 17 00:00:00 2001 From: JordonPhillips Date: Tue, 26 Nov 2024 14:45:52 +0100 Subject: [PATCH] Handle empty and incomplete event bytes This adds in explicit handling for both emtpy and incomplete event bytes. If nothing is able to be read from the source, event decoders will return None. If there are bytes there, but they're truncated, then an explicit error is thrown that wraps what would otherwise be a `struct.error`. This is only applied for truncations that would not already be caught by checksum validation. --- .../_private/deserializers.py | 3 + .../aws_event_stream/events.py | 49 +++++++++--- .../aws_event_stream/exceptions.py | 6 ++ .../tests/unit/_private/test_deserializers.py | 4 + .../tests/unit/test_events.py | 74 +++++++++---------- 5 files changed, 89 insertions(+), 47 deletions(-) diff --git a/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py b/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py index bd2254119..b3af08b83 100644 --- a/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py +++ b/python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py @@ -38,6 +38,9 @@ def __init__( async def receive(self) -> E | None: event = await Event.decode_async(self._source) + if event is None: + return None + deserializer = EventDeserializer( event=event, payload_codec=self._payload_codec, diff --git a/python-packages/aws-event-stream/aws_event_stream/events.py b/python-packages/aws-event-stream/aws_event_stream/events.py index ec77473f0..cdd76d493 100644 --- a/python-packages/aws-event-stream/aws_event_stream/events.py +++ b/python-packages/aws-event-stream/aws_event_stream/events.py @@ -24,6 +24,7 @@ from .exceptions import ( ChecksumMismatch, DuplicateHeader, + InvalidEventBytes, InvalidHeadersLength, InvalidHeaderValue, InvalidHeaderValueLength, @@ -286,27 +287,41 @@ class Event: """ @classmethod - def decode(cls, source: BytesReader) -> Self: + def decode(cls, source: BytesReader) -> Self | None: """Decode an event from a byte stream. :param source: An object to read event bytes from. It must have a `read` method that accepts a number of bytes to read. - :returns: An Event representing the next event on the source. + :returns: An Event representing the next event on the source, or None if no + data can be read from the source. """ prelude_bytes = source.read(8) + if not prelude_bytes: + # If nothing can be read from the source, return None. If bytes are missing + # later, that indicates a problem with the source and therefore will result + # in an exception. + return None + prelude_crc_bytes = source.read(4) - prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0] + try: + prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0] + total_length, headers_length = unpack("!II", prelude_bytes) + except struct.error as e: + raise InvalidEventBytes() from e - total_length, headers_length = unpack("!II", prelude_bytes) _validate_checksum(prelude_bytes, prelude_crc) prelude = EventPrelude( total_length=total_length, headers_length=headers_length, crc=prelude_crc ) message_bytes = source.read(total_length - _MESSAGE_METADATA_SIZE) - crc: int = _DecodeUtils.unpack_uint32(source.read(4))[0] + try: + crc: int = _DecodeUtils.unpack_uint32(source.read(4))[0] + except struct.error as e: + raise InvalidEventBytes() from e + _validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc) message = EventMessage( @@ -316,27 +331,41 @@ def decode(cls, source: BytesReader) -> Self: return cls(prelude, message, crc) @classmethod - async def decode_async(cls, source: AsyncByteStream) -> Self: + async def decode_async(cls, source: AsyncByteStream) -> Self | None: """Decode an event from an async byte stream. :param source: An object to read event bytes from. It must have a `read` method that accepts a number of bytes to read. - :returns: An Event representing the next event on the source. + :returns: An Event representing the next event on the source, or None if no + data can be read from the source. """ prelude_bytes = await source.read(8) + if not prelude_bytes: + # If nothing can be read from the source, return None. If bytes are missing + # later, that indicates a problem with the source and therefore will result + # in an exception. + return None + prelude_crc_bytes = await source.read(4) - prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0] + try: + prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0] + total_length, headers_length = unpack("!II", prelude_bytes) + except struct.error as e: + raise InvalidEventBytes() from e - total_length, headers_length = unpack("!II", prelude_bytes) _validate_checksum(prelude_bytes, prelude_crc) prelude = EventPrelude( total_length=total_length, headers_length=headers_length, crc=prelude_crc ) message_bytes = await source.read(total_length - _MESSAGE_METADATA_SIZE) - crc: int = _DecodeUtils.unpack_uint32(await source.read(4))[0] + try: + crc: int = _DecodeUtils.unpack_uint32(await source.read(4))[0] + except struct.error as e: + raise InvalidEventBytes() from e + _validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc) message = EventMessage( diff --git a/python-packages/aws-event-stream/aws_event_stream/exceptions.py b/python-packages/aws-event-stream/aws_event_stream/exceptions.py index ec705c611..69293c8b2 100644 --- a/python-packages/aws-event-stream/aws_event_stream/exceptions.py +++ b/python-packages/aws-event-stream/aws_event_stream/exceptions.py @@ -93,6 +93,12 @@ def __init__(self, size: str, value: int): super().__init__(message) +class InvalidEventBytes(EventError): + def __init__(self) -> None: + message = "Invalid event bytes." + super().__init__(message) + + class MissingInitialResponse(EventError): def __init__(self) -> None: super().__init__("Expected an initial response, but none was found.") diff --git a/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py b/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py index 301533905..3e4af5def 100644 --- a/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py +++ b/python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py @@ -22,6 +22,7 @@ @pytest.mark.parametrize("expected,given", EVENT_STREAM_SERDE_CASES) def test_event_deserializer(expected: DeserializeableShape, given: EventMessage): source = Event.decode(BytesIO(given.encode())) + assert source is not None deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) result = EventStreamDeserializer().deserialize(deserializer) assert result == expected @@ -30,6 +31,7 @@ def test_event_deserializer(expected: DeserializeableShape, given: EventMessage) def test_deserialize_initial_request(): expected, given = INITIAL_REQUEST_CASE source = Event.decode(BytesIO(given.encode())) + assert source is not None deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) result = EventStreamOperationInputOutput.deserialize(deserializer) assert result == expected @@ -38,6 +40,7 @@ def test_deserialize_initial_request(): def test_deserialize_initial_response(): expected, given = INITIAL_RESPONSE_CASE source = Event.decode(BytesIO(given.encode())) + assert source is not None deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) result = EventStreamOperationInputOutput.deserialize(deserializer) assert result == expected @@ -52,6 +55,7 @@ def test_deserialize_unmodeled_error(): } ) source = Event.decode(BytesIO(message.encode())) + assert source is not None deserializer = EventDeserializer(event=source, payload_codec=JSONCodec()) with pytest.raises(UnmodeledEventError, match="InternalError"): diff --git a/python-packages/aws-event-stream/tests/unit/test_events.py b/python-packages/aws-event-stream/tests/unit/test_events.py index d6c534bb8..97f97a276 100644 --- a/python-packages/aws-event-stream/tests/unit/test_events.py +++ b/python-packages/aws-event-stream/tests/unit/test_events.py @@ -25,6 +25,7 @@ from aws_event_stream.exceptions import ( ChecksumMismatch, DuplicateHeader, + InvalidEventBytes, InvalidHeadersLength, InvalidHeaderValueLength, InvalidIntegerValue, @@ -381,6 +382,8 @@ ), ) +EMPTY_SOURCE = (b"", None) + # Tuples of encoded messages and their expected decoded output POSITIVE_CASES = [ EMPTY_MESSAGE, # standard @@ -398,6 +401,7 @@ PAYLOAD_ONE_STR_HEADER, # standard ALL_HEADERS_TYPES, # standard ERROR_EVENT_MESSAGE, + EMPTY_SOURCE, ] CORRUPTED_HEADERS_LENGTH = ( @@ -489,47 +493,54 @@ InvalidPayloadLength, ) +TRUNCATED_PRELUDE = (b"\x00", InvalidEventBytes) + +MISSING_PRELUDE_CRC_BYTES = (b"\x00\x00\x00\x16", InvalidEventBytes) + +MISSING_MESSAGE_CRC_BYTES = ( + ( + b"\x00\x00\x00\x10" # total length + b"\x00\x00\x00\x00" # headers length + b"\x05\xc2\x48\xeb" # prelude crc + ), + InvalidEventBytes, +) + # Tuples of encoded messages and their expected exception -NEGATIVE_CASES = [ - CORRUPTED_LENGTH, # standard - CORRUPTED_PAYLOAD, # standard - CORRUPTED_HEADERS, # standard - CORRUPTED_HEADERS_LENGTH, # standard - DUPLICATE_HEADER, - INVALID_HEADERS_LENGTH, - INVALID_HEADER_VALUE_LENGTH, - INVALID_PAYLOAD_LENGTH, -] +NEGATIVE_CASES = { + "corrupted-length": CORRUPTED_LENGTH, # standard + "corrupted-payload": CORRUPTED_PAYLOAD, # standard + "corrupted-headers": CORRUPTED_HEADERS, # standard + "corrupted-headers-length": CORRUPTED_HEADERS_LENGTH, # standard + "duplicate-header": DUPLICATE_HEADER, + "invalid-headers-length": INVALID_HEADERS_LENGTH, + "invalid-header-value-length": INVALID_HEADER_VALUE_LENGTH, + "invalid-payload-length": INVALID_PAYLOAD_LENGTH, + "truncated-prelude": TRUNCATED_PRELUDE, + "missing-prelude-crc-bytes": MISSING_PRELUDE_CRC_BYTES, + "missing-message-crc-bytes": MISSING_MESSAGE_CRC_BYTES, +} @pytest.mark.parametrize("encoded,expected", POSITIVE_CASES) -def test_decode(encoded: bytes, expected: Event): +def test_decode(encoded: bytes, expected: Event | None): assert Event.decode(BytesIO(encoded)) == expected @pytest.mark.parametrize("encoded,expected", POSITIVE_CASES) -async def test_decode_async(encoded: bytes, expected: Event): +async def test_decode_async(encoded: bytes, expected: Event | None): assert await Event.decode_async(AsyncBytesReader(encoded)) == expected -@pytest.mark.parametrize("expected,event", POSITIVE_CASES) +@pytest.mark.parametrize( + "expected,event", [c for c in POSITIVE_CASES if c[1] is not None] +) def test_encode(expected: bytes, event: Event): assert event.message.encode() == expected @pytest.mark.parametrize( - "encoded,expected", - NEGATIVE_CASES, - ids=[ - "corrupted-length", - "corrupted-payload", - "corrupted-headers", - "corrupted-headers-length", - "duplicate-header", - "invalid-headers-length", - "invalid-header-value-length", - "invalid-payload-length", - ], + "encoded,expected", NEGATIVE_CASES.values(), ids=NEGATIVE_CASES.keys() ) def test_negative_cases(encoded: bytes, expected: type[Exception]): with pytest.raises(expected): @@ -537,18 +548,7 @@ def test_negative_cases(encoded: bytes, expected: type[Exception]): @pytest.mark.parametrize( - "encoded,expected", - NEGATIVE_CASES, - ids=[ - "corrupted-length", - "corrupted-payload", - "corrupted-headers", - "corrupted-headers-length", - "duplicate-header", - "invalid-headers-length", - "invalid-header-value-length", - "invalid-payload-length", - ], + "encoded,expected", NEGATIVE_CASES.values(), ids=NEGATIVE_CASES.keys() ) async def test_negative_cases_async(encoded: bytes, expected: type[Exception]): with pytest.raises(expected):