Skip to content

Commit 4e71c2e

Browse files
authored
💥 Payload limit configuration and validation (#1288)
1 parent bd63236 commit 4e71c2e

File tree

9 files changed

+888
-112
lines changed

9 files changed

+888
-112
lines changed

‎temporalio/bridge/worker.py‎

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,6 @@ async def encode_completion(
315315
encode_headers: bool,
316316
) -> None:
317317
"""Encode all payloads in the completion."""
318-
if data_converter._encode_payload_has_effect:
319-
await CommandAwarePayloadVisitor(
320-
skip_search_attributes=True, skip_headers=not encode_headers
321-
).visit(_Visitor(data_converter._encode_payload_sequence), completion)
318+
await CommandAwarePayloadVisitor(
319+
skip_search_attributes=True, skip_headers=not encode_headers
320+
).visit(_Visitor(data_converter._encode_payload_sequence), completion)

‎temporalio/client.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9178,7 +9178,7 @@ async def _apply_headers(
91789178
) -> None:
91799179
if source is None:
91809180
return
9181-
if encode_headers and data_converter._encode_payload_has_effect:
9181+
if encode_headers:
91829182
for payload in source.values():
91839183
payload.CopyFrom(await data_converter._encode_payload(payload))
91849184
temporalio.common._apply_headers(source, dest)

‎temporalio/converter.py‎

Lines changed: 108 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -930,7 +930,10 @@ def to_failure(
930930
else:
931931
# Convert to failure error
932932
failure_error = temporalio.exceptions.ApplicationError(
933-
str(exception), type=exception.__class__.__name__
933+
str(exception),
934+
type="PayloadSizeError"
935+
if isinstance(exception, _PayloadSizeError)
936+
else exception.__class__.__name__,
934937
)
935938
failure_error.__traceback__ = exception.__traceback__
936939
failure_error.__cause__ = exception.__cause__
@@ -1220,6 +1223,46 @@ def __init__(self) -> None:
12201223
super().__init__(encode_common_attributes=True)
12211224

12221225

1226+
@dataclass(frozen=True)
1227+
class PayloadLimitsConfig:
1228+
"""Configuration for when payload sizes exceed limits."""
1229+
1230+
memo_size_warning: int = 2 * 1024
1231+
"""The limit (in bytes) at which a memo size warning is logged."""
1232+
1233+
payload_size_warning: int = 512 * 1024
1234+
"""The limit (in bytes) at which a payload size warning is logged."""
1235+
1236+
1237+
class PayloadSizeWarning(RuntimeWarning):
1238+
"""The size of payloads is above the warning limit."""
1239+
1240+
1241+
class _PayloadSizeError(temporalio.exceptions.TemporalError):
1242+
"""Error raised when payloads size exceeds payload size limits."""
1243+
1244+
def __init__(self, message: str):
1245+
"""Initialize a payloads size error."""
1246+
super().__init__(message)
1247+
self._message = message
1248+
1249+
@property
1250+
def message(self) -> str:
1251+
"""Message."""
1252+
return self._message
1253+
1254+
1255+
@dataclass(frozen=True)
1256+
class _ServerPayloadErrorLimits:
1257+
"""Error limits for payloads as described by the Temporal server."""
1258+
1259+
memo_size_error: int
1260+
"""The limit (in bytes) at which a memo size error is raised."""
1261+
1262+
payload_size_error: int
1263+
"""The limit (in bytes) at which a payload size error is raised."""
1264+
1265+
12231266
@dataclass(frozen=True)
12241267
class DataConverter(WithSerializationContext):
12251268
"""Data converter for converting and encoding payloads to/from Python values.
@@ -1243,9 +1286,15 @@ class DataConverter(WithSerializationContext):
12431286
failure_converter: FailureConverter = dataclasses.field(init=False)
12441287
"""Failure converter created from the :py:attr:`failure_converter_class`."""
12451288

1289+
payload_limits: PayloadLimitsConfig = PayloadLimitsConfig()
1290+
"""Settings for payload size limits."""
1291+
12461292
default: ClassVar[DataConverter]
12471293
"""Singleton default data converter."""
12481294

1295+
_payload_error_limits: _ServerPayloadErrorLimits | None = None
1296+
"""Server-reported limits for payloads."""
1297+
12491298
def __post_init__(self) -> None: # noqa: D105
12501299
object.__setattr__(self, "payload_converter", self.payload_converter_class())
12511300
object.__setattr__(self, "failure_converter", self.failure_converter_class())
@@ -1347,6 +1396,11 @@ def with_context(self, context: SerializationContext) -> Self:
13471396
object.__setattr__(cloned, "failure_converter", failure_converter)
13481397
return cloned
13491398

1399+
def _with_payload_error_limits(
1400+
self, limits: _ServerPayloadErrorLimits | None
1401+
) -> DataConverter:
1402+
return dataclasses.replace(self, _payload_error_limits=limits)
1403+
13501404
async def _decode_memo(
13511405
self,
13521406
source: temporalio.api.common.v1.Memo,
@@ -1385,30 +1439,38 @@ async def _encode_memo_existing(
13851439
if not isinstance(v, temporalio.api.common.v1.Payload):
13861440
payload = (await self.encode([v]))[0]
13871441
memo.fields[k].CopyFrom(payload)
1442+
# Memos have their field payloads validated all together in one unit
1443+
DataConverter._validate_limits(
1444+
list(memo.fields.values()),
1445+
self._payload_error_limits.memo_size_error
1446+
if self._payload_error_limits
1447+
else None,
1448+
"[TMPRL1103] Attempted to upload memo with size that exceeded the error limit.",
1449+
self.payload_limits.memo_size_warning,
1450+
"[TMPRL1103] Attempted to upload memo with size that exceeded the warning limit.",
1451+
)
13881452

13891453
async def _encode_payload(
13901454
self, payload: temporalio.api.common.v1.Payload
13911455
) -> temporalio.api.common.v1.Payload:
13921456
if self.payload_codec:
13931457
payload = (await self.payload_codec.encode([payload]))[0]
1458+
self._validate_payload_limits([payload])
13941459
return payload
13951460

13961461
async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads):
13971462
if self.payload_codec:
13981463
await self.payload_codec.encode_wrapper(payloads)
1464+
self._validate_payload_limits(payloads.payloads)
13991465

14001466
async def _encode_payload_sequence(
14011467
self, payloads: Sequence[temporalio.api.common.v1.Payload]
14021468
) -> list[temporalio.api.common.v1.Payload]:
1403-
if not self.payload_codec:
1404-
return list(payloads)
1405-
return await self.payload_codec.encode(payloads)
1406-
1407-
# Temporary shortcircuit detection while the _encode_* methods may no-op if
1408-
# a payload codec is not configured. Remove once those paths have more to them.
1409-
@property
1410-
def _encode_payload_has_effect(self) -> bool:
1411-
return self.payload_codec is not None
1469+
encoded_payloads = list(payloads)
1470+
if self.payload_codec:
1471+
encoded_payloads = await self.payload_codec.encode(encoded_payloads)
1472+
self._validate_payload_limits(encoded_payloads)
1473+
return encoded_payloads
14121474

14131475
async def _decode_payload(
14141476
self, payload: temporalio.api.common.v1.Payload
@@ -1465,6 +1527,42 @@ async def _apply_to_failure_payloads(
14651527
if failure.HasField("cause"):
14661528
await DataConverter._apply_to_failure_payloads(failure.cause, cb)
14671529

1530+
def _validate_payload_limits(
1531+
self,
1532+
payloads: Sequence[temporalio.api.common.v1.Payload],
1533+
):
1534+
DataConverter._validate_limits(
1535+
payloads,
1536+
self._payload_error_limits.payload_size_error
1537+
if self._payload_error_limits
1538+
else None,
1539+
"[TMPRL1103] Attempted to upload payloads with size that exceeded the error limit.",
1540+
self.payload_limits.payload_size_warning,
1541+
"[TMPRL1103] Attempted to upload payloads with size that exceeded the warning limit.",
1542+
)
1543+
1544+
@staticmethod
1545+
def _validate_limits(
1546+
payloads: Sequence[temporalio.api.common.v1.Payload],
1547+
error_limit: int | None,
1548+
error_message: str,
1549+
warning_limit: int,
1550+
warning_message: str,
1551+
):
1552+
total_size = sum(payload.ByteSize() for payload in payloads)
1553+
1554+
if error_limit and error_limit > 0 and total_size > error_limit:
1555+
raise _PayloadSizeError(
1556+
f"{error_message} Size: {total_size} bytes, Limit: {error_limit} bytes"
1557+
)
1558+
1559+
if warning_limit > 0 and total_size > warning_limit:
1560+
# TODO: Use a context aware logger to log extra information about workflow/activity/etc
1561+
warnings.warn(
1562+
f"{warning_message} Size: {total_size} bytes, Limit: {warning_limit} bytes",
1563+
PayloadSizeWarning,
1564+
)
1565+
14681566

14691567
DefaultPayloadConverter.default_encoding_payload_converters = (
14701568
BinaryNullPayloadConverter(),

0 commit comments

Comments
 (0)