Skip to content

Commit cdfd725

Browse files
committed
Fix payload codec usage
- Revert "Store data_converter on WorkflowInstanceDetails instead of converter classes" - Don't require payload codec on workflow instance - Don't stack contexts Revert "Store data_converter on WorkflowInstanceDetails instead of converter classes" This reverts commit f869af87ba5bffb891f8abb3189e655921247ca3.
1 parent 7679570 commit cdfd725

File tree

5 files changed

+66
-47
lines changed

5 files changed

+66
-47
lines changed

temporalio/worker/_workflow.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -249,13 +249,12 @@ async def _handle_activation(
249249
await self._handle_cache_eviction(act, cache_remove_job)
250250
return
251251

252-
data_converter = self._data_converter
253252
# Build default success completion (e.g. remove-job-only activations)
254253
completion = (
255254
temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion()
256255
)
257256
completion.successful.SetInParent()
258-
workflow = None
257+
workflow = workflow_id = None
259258
try:
260259
if LOG_PROTOS:
261260
logger.debug("Received workflow activation:\n%s", act)
@@ -275,17 +274,14 @@ async def _handle_activation(
275274
"Cache already exists for activation with initialize job"
276275
)
277276

278-
data_converter = self._data_converter._with_context(
279-
temporalio.converter.WorkflowSerializationContext(
280-
namespace=self._namespace,
281-
workflow_id=workflow_id,
282-
)
283-
)
284-
if data_converter.payload_codec:
277+
if self._data_converter.payload_codec:
285278
if not workflow:
286-
payload_codec = data_converter.payload_codec
279+
payload_codec = self._data_converter.payload_codec
287280
else:
288-
payload_codec = _CommandAwarePayloadCodec(workflow.instance)
281+
payload_codec = _CommandAwarePayloadCodec(
282+
workflow.instance,
283+
self._data_converter.payload_codec,
284+
)
289285
await temporalio.bridge.worker.decode_activation(
290286
act,
291287
payload_codec.decode,
@@ -339,6 +335,14 @@ async def _handle_activation(
339335

340336
completion.failed.failure.SetInParent()
341337
try:
338+
data_converter = self._data_converter
339+
if workflow_id:
340+
data_converter = data_converter._with_context(
341+
temporalio.converter.WorkflowSerializationContext(
342+
namespace=self._namespace,
343+
workflow_id=workflow_id,
344+
)
345+
)
342346
data_converter.failure_converter.to_failure(
343347
err,
344348
data_converter.payload_converter,
@@ -356,8 +360,11 @@ async def _handle_activation(
356360
completion.run_id = act.run_id
357361

358362
# Encode completion
359-
if data_converter.payload_codec and workflow:
360-
payload_codec = _CommandAwarePayloadCodec(workflow.instance)
363+
if self._data_converter.payload_codec and workflow:
364+
payload_codec = _CommandAwarePayloadCodec(
365+
workflow.instance,
366+
self._data_converter.payload_codec,
367+
)
361368
try:
362369
await temporalio.bridge.worker.encode_completion(
363370
completion,
@@ -572,7 +579,8 @@ def _create_workflow_instance(
572579

573580
# Create instance from details
574581
det = WorkflowInstanceDetails(
575-
data_converter=self._data_converter,
582+
payload_converter_class=self._data_converter.payload_converter_class,
583+
failure_converter_class=self._data_converter.failure_converter_class,
576584
interceptor_classes=self._interceptor_classes,
577585
defn=defn,
578586
info=info,
@@ -722,8 +730,10 @@ class _CommandAwarePayloadCodec(temporalio.converter.PayloadCodec):
722730
def __init__(
723731
self,
724732
instance: WorkflowInstance,
733+
context_free_payload_codec: temporalio.converter.PayloadCodec,
725734
):
726735
self.instance = instance
736+
self.context_free_payload_codec = context_free_payload_codec
727737

728738
async def encode(
729739
self,
@@ -738,10 +748,10 @@ async def decode(
738748
return await self._get_current_command_codec().decode(payloads)
739749

740750
def _get_current_command_codec(self) -> temporalio.converter.PayloadCodec:
741-
seq = temporalio.bridge._visitor.current_command_seq.get()
742-
codec = self.instance.get_payload_codec(seq)
743-
assert codec, "Payload codec must be set on the data converter"
744-
return codec
751+
return self.instance.get_payload_codec_with_context(
752+
self.context_free_payload_codec,
753+
temporalio.bridge._visitor.current_command_seq.get(),
754+
)
745755

746756

747757
class _InterruptDeadlockError(BaseException):

temporalio/worker/_workflow_instance.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ def set_worker_level_failure_exception_types(
135135
class WorkflowInstanceDetails:
136136
"""Immutable details for creating a workflow instance."""
137137

138-
data_converter: temporalio.converter.DataConverter
138+
payload_converter_class: Type[temporalio.converter.PayloadConverter]
139+
failure_converter_class: Type[temporalio.converter.FailureConverter]
139140
interceptor_classes: Sequence[Type[WorkflowInboundInterceptor]]
140141
defn: temporalio.workflow._Definition
141142
info: temporalio.workflow.Info
@@ -168,9 +169,11 @@ def activate(
168169
raise NotImplementedError
169170

170171
@abstractmethod
171-
def get_payload_codec(
172-
self, command_seq: Optional[int]
173-
) -> Optional[temporalio.converter.PayloadCodec]:
172+
def get_payload_codec_with_context(
173+
self,
174+
payload_codec: temporalio.converter.PayloadCodec,
175+
command_seq: Optional[int],
176+
) -> temporalio.converter.PayloadCodec:
174177
"""Return a payload codec with appropriate serialization context.
175178
176179
Args:
@@ -224,13 +227,8 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
224227
self._defn = det.defn
225228
self._workflow_input: Optional[ExecuteWorkflowInput] = None
226229
self._info = det.info
227-
self._context_free_payload_codec = det.data_converter.payload_codec
228-
self._context_free_payload_converter = (
229-
det.data_converter.payload_converter_class()
230-
)
231-
self._context_free_failure_converter = (
232-
det.data_converter.failure_converter_class()
233-
)
230+
self._context_free_payload_converter = det.payload_converter_class()
231+
self._context_free_failure_converter = det.failure_converter_class()
234232
self._payload_converter, self._failure_converter = (
235233
self._converters_with_context(
236234
temporalio.converter.WorkflowSerializationContext(
@@ -2099,10 +2097,11 @@ def _converters_with_context(
20992097
return payload_converter, failure_converter
21002098

21012099
# _WorkflowInstanceImpl.get_pending_command_serialization_context
2102-
def get_payload_codec(
2103-
self, command_seq: Optional[int]
2104-
) -> Optional[temporalio.converter.PayloadCodec]:
2105-
payload_codec = self._context_free_payload_codec
2100+
def get_payload_codec_with_context(
2101+
self,
2102+
payload_codec: temporalio.converter.PayloadCodec,
2103+
command_seq: Optional[int],
2104+
) -> temporalio.converter.PayloadCodec:
21062105
if not isinstance(
21072106
payload_codec,
21082107
temporalio.converter.WithSerializationContext,

temporalio/worker/workflow_sandbox/_in_sandbox.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,10 @@ def activate(
8181
"""Send activation to this instance."""
8282
return self.instance.activate(act)
8383

84-
def get_payload_codec(
85-
self, command_seq: Optional[int]
86-
) -> Optional[temporalio.converter.PayloadCodec]:
87-
"""Get payload codec."""
88-
return self.instance.get_payload_codec(command_seq)
84+
def get_payload_codec_with_context(
85+
self,
86+
payload_codec: temporalio.converter.PayloadCodec,
87+
command_seq: Optional[int],
88+
) -> temporalio.converter.PayloadCodec:
89+
"""Get payload codec with context."""
90+
return self.instance.get_payload_codec_with_context(payload_codec, command_seq)

temporalio/worker/workflow_sandbox/_runner.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ def prepare_workflow(self, defn: temporalio.workflow._Definition) -> None:
7777
# Just create with fake info which validates
7878
self.create_instance(
7979
WorkflowInstanceDetails(
80-
data_converter=temporalio.converter.DataConverter.default,
80+
payload_converter_class=temporalio.converter.DataConverter.default.payload_converter_class,
81+
failure_converter_class=temporalio.converter.DataConverter.default.failure_converter_class,
8182
interceptor_classes=[],
8283
defn=defn,
8384
# Just use fake info during validation
@@ -185,16 +186,19 @@ def _run_code(self, code: str, **extra_globals: Any) -> None:
185186
def get_thread_id(self) -> Optional[int]:
186187
return self._current_thread_id
187188

188-
def get_payload_codec(
189-
self, command_seq: Optional[int]
190-
) -> Optional[temporalio.converter.PayloadCodec]:
189+
def get_payload_codec_with_context(
190+
self,
191+
payload_codec: temporalio.converter.PayloadCodec,
192+
command_seq: Optional[int],
193+
) -> temporalio.converter.PayloadCodec:
191194
# Forward call to the sandboxed instance
192195
self.importer.restriction_context.is_runtime = True
193196
try:
194197
self._run_code(
195198
"with __temporal_importer.applied():\n"
196-
" __temporal_codec = __temporal_in_sandbox.get_payload_codec(__temporal_command_seq)\n",
199+
" __temporal_codec = __temporal_in_sandbox.get_payload_codec_with_context(__temporal_payload_codec, __temporal_command_seq)\n",
197200
__temporal_importer=self.importer,
201+
__temporal_payload_codec=payload_codec,
198202
__temporal_command_seq=command_seq,
199203
)
200204
return self.globals_and_locals.pop("__temporal_codec", None) # type: ignore

tests/worker/test_workflow.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1610,10 +1610,14 @@ def activate(self, act: WorkflowActivation) -> WorkflowActivationCompletion:
16101610
self._runner._pairs.append((act, comp))
16111611
return comp
16121612

1613-
def get_payload_codec(
1614-
self, command_seq: Optional[int]
1615-
) -> Optional[temporalio.converter.PayloadCodec]:
1616-
return self._unsandboxed.get_payload_codec(command_seq)
1613+
def get_payload_codec_with_context(
1614+
self,
1615+
payload_codec: temporalio.converter.PayloadCodec,
1616+
command_seq: Optional[int],
1617+
) -> temporalio.converter.PayloadCodec:
1618+
return self._unsandboxed.get_payload_codec_with_context(
1619+
payload_codec, command_seq
1620+
)
16171621

16181622

16191623
async def test_workflow_with_custom_runner(client: Client):

0 commit comments

Comments
 (0)