Skip to content

Commit 2758bf9

Browse files
Handle empty and incomplete event bytes
This adds in explicit handling for both emtpy and incomplete event bytes. If nothing is able to be read from the source, event decoders will return None. If there are bytes there, but they're truncated, then an explicit error is thrown that wraps what would otherwise be a `struct.error`. This is only applied for truncations that would not already be caught by checksum validation.
1 parent 91df81b commit 2758bf9

File tree

5 files changed

+89
-47
lines changed

5 files changed

+89
-47
lines changed

python-packages/aws-event-stream/aws_event_stream/_private/deserializers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def __init__(
3838

3939
async def receive(self) -> E | None:
4040
event = await Event.decode_async(self._source)
41+
if event is None:
42+
return None
43+
4144
deserializer = EventDeserializer(
4245
event=event,
4346
payload_codec=self._payload_codec,

python-packages/aws-event-stream/aws_event_stream/events.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .exceptions import (
2525
ChecksumMismatch,
2626
DuplicateHeader,
27+
InvalidEventBytes,
2728
InvalidHeadersLength,
2829
InvalidHeaderValue,
2930
InvalidHeaderValueLength,
@@ -286,27 +287,41 @@ class Event:
286287
"""
287288

288289
@classmethod
289-
def decode(cls, source: BytesReader) -> Self:
290+
def decode(cls, source: BytesReader) -> Self | None:
290291
"""Decode an event from a byte stream.
291292
292293
:param source: An object to read event bytes from. It must have a `read` method
293294
that accepts a number of bytes to read.
294295
295-
:returns: An Event representing the next event on the source.
296+
:returns: An Event representing the next event on the source, or None if no
297+
data can be read from the source.
296298
"""
297299

298300
prelude_bytes = source.read(8)
301+
if not prelude_bytes:
302+
# If nothing can be read from the source, return None. If bytes are missing
303+
# later, that indicates a problem with the source and therefore will result
304+
# in an exception.
305+
return None
306+
299307
prelude_crc_bytes = source.read(4)
300-
prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0]
308+
try:
309+
prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0]
310+
total_length, headers_length = unpack("!II", prelude_bytes)
311+
except struct.error as e:
312+
raise InvalidEventBytes() from e
301313

302-
total_length, headers_length = unpack("!II", prelude_bytes)
303314
_validate_checksum(prelude_bytes, prelude_crc)
304315
prelude = EventPrelude(
305316
total_length=total_length, headers_length=headers_length, crc=prelude_crc
306317
)
307318

308319
message_bytes = source.read(total_length - _MESSAGE_METADATA_SIZE)
309-
crc: int = _DecodeUtils.unpack_uint32(source.read(4))[0]
320+
try:
321+
crc: int = _DecodeUtils.unpack_uint32(source.read(4))[0]
322+
except struct.error as e:
323+
raise InvalidEventBytes() from e
324+
310325
_validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc)
311326

312327
message = EventMessage(
@@ -316,27 +331,41 @@ def decode(cls, source: BytesReader) -> Self:
316331
return cls(prelude, message, crc)
317332

318333
@classmethod
319-
async def decode_async(cls, source: AsyncByteStream) -> Self:
334+
async def decode_async(cls, source: AsyncByteStream) -> Self | None:
320335
"""Decode an event from an async byte stream.
321336
322337
:param source: An object to read event bytes from. It must have a `read` method
323338
that accepts a number of bytes to read.
324339
325-
:returns: An Event representing the next event on the source.
340+
:returns: An Event representing the next event on the source, or None if no
341+
data can be read from the source.
326342
"""
327343

328344
prelude_bytes = await source.read(8)
345+
if not prelude_bytes:
346+
# If nothing can be read from the source, return None. If bytes are missing
347+
# later, that indicates a problem with the source and therefore will result
348+
# in an exception.
349+
return None
350+
329351
prelude_crc_bytes = await source.read(4)
330-
prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0]
352+
try:
353+
prelude_crc: int = _DecodeUtils.unpack_uint32(prelude_crc_bytes)[0]
354+
total_length, headers_length = unpack("!II", prelude_bytes)
355+
except struct.error as e:
356+
raise InvalidEventBytes() from e
331357

332-
total_length, headers_length = unpack("!II", prelude_bytes)
333358
_validate_checksum(prelude_bytes, prelude_crc)
334359
prelude = EventPrelude(
335360
total_length=total_length, headers_length=headers_length, crc=prelude_crc
336361
)
337362

338363
message_bytes = await source.read(total_length - _MESSAGE_METADATA_SIZE)
339-
crc: int = _DecodeUtils.unpack_uint32(await source.read(4))[0]
364+
try:
365+
crc: int = _DecodeUtils.unpack_uint32(await source.read(4))[0]
366+
except struct.error as e:
367+
raise InvalidEventBytes() from e
368+
340369
_validate_checksum(prelude_crc_bytes + message_bytes, crc, prelude_crc)
341370

342371
message = EventMessage(

python-packages/aws-event-stream/aws_event_stream/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,12 @@ def __init__(self, size: str, value: int):
9393
super().__init__(message)
9494

9595

96+
class InvalidEventBytes(EventError):
97+
def __init__(self) -> None:
98+
message = "Invalid event bytes."
99+
super().__init__(message)
100+
101+
96102
class MissingInitialResponse(EventError):
97103
def __init__(self) -> None:
98104
super().__init__("Expected an initial response, but none was found.")

python-packages/aws-event-stream/tests/unit/_private/test_deserializers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
@pytest.mark.parametrize("expected,given", EVENT_STREAM_SERDE_CASES)
2323
def test_event_deserializer(expected: DeserializeableShape, given: EventMessage):
2424
source = Event.decode(BytesIO(given.encode()))
25+
assert source is not None
2526
deserializer = EventDeserializer(event=source, payload_codec=JSONCodec())
2627
result = EventStreamDeserializer().deserialize(deserializer)
2728
assert result == expected
@@ -30,6 +31,7 @@ def test_event_deserializer(expected: DeserializeableShape, given: EventMessage)
3031
def test_deserialize_initial_request():
3132
expected, given = INITIAL_REQUEST_CASE
3233
source = Event.decode(BytesIO(given.encode()))
34+
assert source is not None
3335
deserializer = EventDeserializer(event=source, payload_codec=JSONCodec())
3436
result = EventStreamOperationInputOutput.deserialize(deserializer)
3537
assert result == expected
@@ -38,6 +40,7 @@ def test_deserialize_initial_request():
3840
def test_deserialize_initial_response():
3941
expected, given = INITIAL_RESPONSE_CASE
4042
source = Event.decode(BytesIO(given.encode()))
43+
assert source is not None
4144
deserializer = EventDeserializer(event=source, payload_codec=JSONCodec())
4245
result = EventStreamOperationInputOutput.deserialize(deserializer)
4346
assert result == expected
@@ -52,6 +55,7 @@ def test_deserialize_unmodeled_error():
5255
}
5356
)
5457
source = Event.decode(BytesIO(message.encode()))
58+
assert source is not None
5559
deserializer = EventDeserializer(event=source, payload_codec=JSONCodec())
5660

5761
with pytest.raises(UnmodeledEventError, match="InternalError"):

python-packages/aws-event-stream/tests/unit/test_events.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from aws_event_stream.exceptions import (
2626
ChecksumMismatch,
2727
DuplicateHeader,
28+
InvalidEventBytes,
2829
InvalidHeadersLength,
2930
InvalidHeaderValueLength,
3031
InvalidIntegerValue,
@@ -381,6 +382,8 @@
381382
),
382383
)
383384

385+
EMPTY_SOURCE = (b"", None)
386+
384387
# Tuples of encoded messages and their expected decoded output
385388
POSITIVE_CASES = [
386389
EMPTY_MESSAGE, # standard
@@ -398,6 +401,7 @@
398401
PAYLOAD_ONE_STR_HEADER, # standard
399402
ALL_HEADERS_TYPES, # standard
400403
ERROR_EVENT_MESSAGE,
404+
EMPTY_SOURCE,
401405
]
402406

403407
CORRUPTED_HEADERS_LENGTH = (
@@ -489,66 +493,62 @@
489493
InvalidPayloadLength,
490494
)
491495

496+
TRUNCATED_PRELUDE = (b"\x00", InvalidEventBytes)
497+
498+
MISSING_PRELUDE_CRC_BYTES = (b"\x00\x00\x00\x16", InvalidEventBytes)
499+
500+
MISSING_MESSAGE_CRC_BYTES = (
501+
(
502+
b"\x00\x00\x00\x10" # total length
503+
b"\x00\x00\x00\x00" # headers length
504+
b"\x05\xc2\x48\xeb" # prelude crc
505+
),
506+
InvalidEventBytes,
507+
)
508+
492509
# Tuples of encoded messages and their expected exception
493-
NEGATIVE_CASES = [
494-
CORRUPTED_LENGTH, # standard
495-
CORRUPTED_PAYLOAD, # standard
496-
CORRUPTED_HEADERS, # standard
497-
CORRUPTED_HEADERS_LENGTH, # standard
498-
DUPLICATE_HEADER,
499-
INVALID_HEADERS_LENGTH,
500-
INVALID_HEADER_VALUE_LENGTH,
501-
INVALID_PAYLOAD_LENGTH,
502-
]
510+
NEGATIVE_CASES = {
511+
"corrupted-length": CORRUPTED_LENGTH, # standard
512+
"corrupted-payload": CORRUPTED_PAYLOAD, # standard
513+
"corrupted-headers": CORRUPTED_HEADERS, # standard
514+
"corrupted-headers-length": CORRUPTED_HEADERS_LENGTH, # standard
515+
"duplicate-header": DUPLICATE_HEADER,
516+
"invalid-headers-length": INVALID_HEADERS_LENGTH,
517+
"invalid-header-value-length": INVALID_HEADER_VALUE_LENGTH,
518+
"invalid-payload-length": INVALID_PAYLOAD_LENGTH,
519+
"truncated-prelude": TRUNCATED_PRELUDE,
520+
"missing-prelude-crc-bytes": MISSING_PRELUDE_CRC_BYTES,
521+
"missing-message-crc-bytes": MISSING_MESSAGE_CRC_BYTES,
522+
}
503523

504524

505525
@pytest.mark.parametrize("encoded,expected", POSITIVE_CASES)
506-
def test_decode(encoded: bytes, expected: Event):
526+
def test_decode(encoded: bytes, expected: Event | None):
507527
assert Event.decode(BytesIO(encoded)) == expected
508528

509529

510530
@pytest.mark.parametrize("encoded,expected", POSITIVE_CASES)
511-
async def test_decode_async(encoded: bytes, expected: Event):
531+
async def test_decode_async(encoded: bytes, expected: Event | None):
512532
assert await Event.decode_async(AsyncBytesReader(encoded)) == expected
513533

514534

515-
@pytest.mark.parametrize("expected,event", POSITIVE_CASES)
535+
@pytest.mark.parametrize(
536+
"expected,event", [c for c in POSITIVE_CASES if c[1] is not None]
537+
)
516538
def test_encode(expected: bytes, event: Event):
517539
assert event.message.encode() == expected
518540

519541

520542
@pytest.mark.parametrize(
521-
"encoded,expected",
522-
NEGATIVE_CASES,
523-
ids=[
524-
"corrupted-length",
525-
"corrupted-payload",
526-
"corrupted-headers",
527-
"corrupted-headers-length",
528-
"duplicate-header",
529-
"invalid-headers-length",
530-
"invalid-header-value-length",
531-
"invalid-payload-length",
532-
],
543+
"encoded,expected", NEGATIVE_CASES.values(), ids=NEGATIVE_CASES.keys()
533544
)
534545
def test_negative_cases(encoded: bytes, expected: type[Exception]):
535546
with pytest.raises(expected):
536547
Event.decode(BytesIO(encoded))
537548

538549

539550
@pytest.mark.parametrize(
540-
"encoded,expected",
541-
NEGATIVE_CASES,
542-
ids=[
543-
"corrupted-length",
544-
"corrupted-payload",
545-
"corrupted-headers",
546-
"corrupted-headers-length",
547-
"duplicate-header",
548-
"invalid-headers-length",
549-
"invalid-header-value-length",
550-
"invalid-payload-length",
551-
],
551+
"encoded,expected", NEGATIVE_CASES.values(), ids=NEGATIVE_CASES.keys()
552552
)
553553
async def test_negative_cases_async(encoded: bytes, expected: type[Exception]):
554554
with pytest.raises(expected):

0 commit comments

Comments
 (0)