Skip to content

Commit 3a6c568

Browse files
committed
Cleanup
1 parent fe25627 commit 3a6c568

File tree

3 files changed

+40
-21
lines changed

3 files changed

+40
-21
lines changed

temporalio/converter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ class ActivitySerializationContext(BaseWorkflowSerializationContext):
136136
is_local: bool
137137

138138

139+
# TODO: duck typing or nominal typing?
139140
class WithSerializationContext(ABC):
140141
"""Interface for objects that can use serialization context.
141142
@@ -341,7 +342,6 @@ def __init__(self, *converters: EncodingPayloadConverter) -> None:
341342
Args:
342343
converters: Payload converters to delegate to, in order.
343344
"""
344-
# Insertion order preserved here since Python 3.7
345345
self.converters = {c.encoding.encode(): c for c in converters}
346346

347347
def to_payloads(
@@ -407,7 +407,7 @@ def from_payloads(
407407
return values
408408

409409
def with_context(self, context: SerializationContext) -> CompositePayloadConverter:
410-
"""Return a new instance with the given context."""
410+
"""Return a new instance with context set on the component converters"""
411411
return CompositePayloadConverter(
412412
*(
413413
c.with_context(context)

temporalio/worker/_workflow_instance.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2087,7 +2087,12 @@ def _converters_with_context(
20872087
temporalio.converter.PayloadConverter,
20882088
temporalio.converter.FailureConverter,
20892089
]:
2090-
"""Construct workflow payload and failure converters with the given context."""
2090+
"""Construct workflow payload and failure converters with the given context.
2091+
2092+
This plays a similar role to DataConverter._with_context, but operates on PayloadConverter
2093+
and FailureConverter only (since payload encoding/decoding is done by the worker, outside
2094+
the workflowsandbox).
2095+
"""
20912096
payload_converter = self._context_free_payload_converter
20922097
failure_converter = self._context_free_failure_converter
20932098
if isinstance(payload_converter, temporalio.converter.WithSerializationContext):

tests/test_serialization_context.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
from pydantic import BaseModel
2222
from typing_extensions import Never
2323

24+
import temporalio.api.common.v1
25+
import temporalio.api.failure.v1
2426
from temporalio import activity, workflow
25-
from temporalio.api.common.v1 import Payload
26-
from temporalio.api.failure.v1 import Failure
2727
from temporalio.client import Client, WorkflowFailureError, WorkflowUpdateFailedError
2828
from temporalio.common import RetryPolicy
2929
from temporalio.contrib.pydantic import PydanticJSONPlainPayloadConverter
@@ -82,7 +82,7 @@ def with_context(
8282
converter.context = context
8383
return converter
8484

85-
def to_payload(self, value: Any) -> Optional[Payload]:
85+
def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]:
8686
if not isinstance(value, TraceData):
8787
return None
8888
if isinstance(self.context, WorkflowSerializationContext):
@@ -106,7 +106,11 @@ def to_payload(self, value: Any) -> Optional[Payload]:
106106
payload.metadata["encoding"] = self.encoding.encode()
107107
return payload
108108

109-
def from_payload(self, payload: Payload, type_hint: Optional[Type] = None) -> Any:
109+
def from_payload(
110+
self,
111+
payload: temporalio.api.common.v1.Payload,
112+
type_hint: Optional[Type] = None,
113+
) -> Any:
110114
value = JSONPlainPayloadConverter().from_payload(payload, TraceData)
111115
assert isinstance(value, TraceData)
112116
if isinstance(self.context, WorkflowSerializationContext):
@@ -1002,7 +1006,7 @@ def __init__(self):
10021006

10031007
def with_context(
10041008
self, context: Optional[SerializationContext]
1005-
) -> "FailureConverterWithContext":
1009+
) -> FailureConverterWithContext:
10061010
converter = FailureConverterWithContext()
10071011
converter.context = context
10081012
return converter
@@ -1011,7 +1015,7 @@ def to_failure(
10111015
self,
10121016
exception: BaseException,
10131017
payload_converter: PayloadConverter,
1014-
failure: Failure,
1018+
failure: temporalio.api.failure.v1.Failure,
10151019
) -> None:
10161020
assert isinstance(
10171021
self.context, (WorkflowSerializationContext, ActivitySerializationContext)
@@ -1025,7 +1029,9 @@ def to_failure(
10251029
super().to_failure(exception, payload_converter, failure)
10261030

10271031
def from_failure(
1028-
self, failure: Failure, payload_converter: PayloadConverter
1032+
self,
1033+
failure: temporalio.api.failure.v1.Failure,
1034+
payload_converter: PayloadConverter,
10291035
) -> BaseException:
10301036
assert isinstance(
10311037
self.context, (WorkflowSerializationContext, ActivitySerializationContext)
@@ -1132,12 +1138,14 @@ def __init__(self):
11321138

11331139
def with_context(
11341140
self, context: Optional[SerializationContext]
1135-
) -> "PayloadCodecWithContext":
1141+
) -> PayloadCodecWithContext:
11361142
codec = PayloadCodecWithContext()
11371143
codec.context = context
11381144
return codec
11391145

1140-
async def encode(self, payloads: Sequence[Payload]) -> List[Payload]:
1146+
async def encode(
1147+
self, payloads: Sequence[temporalio.api.common.v1.Payload]
1148+
) -> List[temporalio.api.common.v1.Payload]:
11411149
assert self.context
11421150
if isinstance(self.context, ActivitySerializationContext):
11431151
test_traces[self.context.workflow_id].append(
@@ -1156,7 +1164,9 @@ async def encode(self, payloads: Sequence[Payload]) -> List[Payload]:
11561164
)
11571165
return list(payloads)
11581166

1159-
async def decode(self, payloads: Sequence[Payload]) -> List[Payload]:
1167+
async def decode(
1168+
self, payloads: Sequence[temporalio.api.common.v1.Payload]
1169+
) -> List[temporalio.api.common.v1.Payload]:
11601170
assert self.context
11611171
if isinstance(self.context, ActivitySerializationContext):
11621172
test_traces[self.context.workflow_id].append(
@@ -1439,20 +1449,24 @@ def with_context(
14391449
codec.context = context
14401450
return codec
14411451

1442-
async def encode(self, payloads: Sequence[Payload]) -> List[Payload]:
1452+
async def encode(
1453+
self, payloads: Sequence[temporalio.api.common.v1.Payload]
1454+
) -> List[temporalio.api.common.v1.Payload]:
14431455
[payload] = payloads
14441456
return [
1445-
Payload(
1457+
temporalio.api.common.v1.Payload(
14461458
metadata=payload.metadata,
14471459
data=json.dumps(self._get_encryption_key()).encode(),
14481460
)
14491461
]
14501462

1451-
async def decode(self, payloads: Sequence[Payload]) -> List[Payload]:
1463+
async def decode(
1464+
self, payloads: Sequence[temporalio.api.common.v1.Payload]
1465+
) -> List[temporalio.api.common.v1.Payload]:
14521466
[payload] = payloads
14531467
assert json.loads(payload.data.decode()) == self._get_encryption_key()
14541468
metadata = dict(payload.metadata)
1455-
return [Payload(metadata=metadata, data=b'"inbound"')]
1469+
return [temporalio.api.common.v1.Payload(metadata=metadata, data=b'"inbound"')]
14561470

14571471
def _get_encryption_key(self) -> str:
14581472
context = (
@@ -1594,8 +1608,8 @@ def with_context(
15941608
return codec
15951609

15961610
async def _assert_context_iff_not_nexus(
1597-
self, payloads: Sequence[Payload]
1598-
) -> List[Payload]:
1611+
self, payloads: Sequence[temporalio.api.common.v1.Payload]
1612+
) -> List[temporalio.api.common.v1.Payload]:
15991613
[payload] = payloads
16001614
assert bool(self.context) == (payload.data.decode() != '"nexus-data"')
16011615
return list(payloads)
@@ -1670,12 +1684,12 @@ def __init__(self):
16701684

16711685
def with_context(
16721686
self, context: Optional[SerializationContext]
1673-
) -> "PydanticJSONConverterWithContext":
1687+
) -> PydanticJSONConverterWithContext:
16741688
converter = PydanticJSONConverterWithContext()
16751689
converter.context = context
16761690
return converter
16771691

1678-
def to_payload(self, value: Any) -> Optional[Payload]:
1692+
def to_payload(self, value: Any) -> Optional[temporalio.api.common.v1.Payload]:
16791693
if isinstance(value, PydanticData) and self.context:
16801694
if isinstance(self.context, WorkflowSerializationContext):
16811695
value.trace.append(f"wf_{self.context.workflow_id}")

0 commit comments

Comments
 (0)