Skip to content

Commit a4d3148

Browse files
committed
Command-aware codec for payload visitor
1 parent ab7c36e commit a4d3148

File tree

2 files changed

+108
-2
lines changed

2 files changed

+108
-2
lines changed

temporalio/worker/_workflow.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import temporalio.activity
2626
import temporalio.api.common.v1
27+
import temporalio.bridge._visitor
2728
import temporalio.bridge.client
2829
import temporalio.bridge.proto.workflow_activation
2930
import temporalio.bridge.proto.workflow_completion
@@ -254,6 +255,7 @@ async def _handle_activation(
254255
temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion()
255256
)
256257
completion.successful.SetInParent()
258+
workflow = None
257259
try:
258260
if LOG_PROTOS:
259261
logger.debug("Received workflow activation:\n%s", act)
@@ -280,9 +282,13 @@ async def _handle_activation(
280282
)
281283
)
282284
if data_converter.payload_codec:
285+
workflow_instance = workflow.instance if workflow else None
286+
payload_codec = _CommandAwarePayloadCodec(
287+
workflow_instance, data_converter.payload_codec
288+
)
283289
await temporalio.bridge.worker.decode_activation(
284290
act,
285-
data_converter.payload_codec,
291+
payload_codec.decode,
286292
decode_headers=self._encode_headers,
287293
)
288294
if not workflow:
@@ -348,13 +354,17 @@ async def _handle_activation(
348354
)
349355

350356
completion.run_id = act.run_id
357+
assert workflow
351358

352359
# Encode completion
353360
if data_converter.payload_codec:
361+
payload_codec = _CommandAwarePayloadCodec(
362+
workflow.instance, data_converter.payload_codec
363+
)
354364
try:
355365
await temporalio.bridge.worker.encode_completion(
356366
completion,
357-
data_converter.payload_codec,
367+
payload_codec.encode,
358368
encode_headers=self._encode_headers,
359369
)
360370
except Exception as err:
@@ -705,5 +715,47 @@ def attempt_deadlock_interruption(self) -> None:
705715
)
706716

707717

718+
class _CommandAwarePayloadCodec(temporalio.converter.PayloadCodec):
719+
"""A payload codec that sets serialization context for the associated command.
720+
721+
This codec responds to the :py:data:`temporalio.bridge._visitor.current_command_seq` context
722+
variable set by the payload visitor.
723+
"""
724+
725+
def __init__(
726+
self,
727+
instance: Optional[WorkflowInstance],
728+
base_codec: temporalio.converter.PayloadCodec,
729+
):
730+
self.instance = instance
731+
self.base_codec = base_codec
732+
733+
async def encode(
734+
self,
735+
payloads: Sequence[temporalio.api.common.v1.Payload],
736+
) -> List[temporalio.api.common.v1.Payload]:
737+
return await self._get_current_command_codec().encode(payloads)
738+
739+
async def decode(
740+
self,
741+
payloads: Sequence[temporalio.api.common.v1.Payload],
742+
) -> List[temporalio.api.common.v1.Payload]:
743+
return await self._get_current_command_codec().decode(payloads)
744+
745+
def _get_current_command_codec(self) -> temporalio.converter.PayloadCodec:
746+
codec = self.base_codec
747+
if self.instance:
748+
if isinstance(
749+
self.base_codec, temporalio.converter.WithSerializationContext
750+
):
751+
if seq := temporalio.bridge._visitor.current_command_seq.get():
752+
if (
753+
context
754+
:= self.instance.get_pending_command_serialization_context(seq)
755+
):
756+
codec = self.base_codec.with_context(context)
757+
return codec
758+
759+
708760
class _InterruptDeadlockError(BaseException):
709761
pass

temporalio/worker/_workflow_instance.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,20 @@ def activate(
167167
"""
168168
raise NotImplementedError
169169

170+
@abstractmethod
171+
def get_pending_command_serialization_context(
172+
self, command_seq: int
173+
) -> Optional[temporalio.converter.SerializationContext]:
174+
"""Return the serialization context for a pending command.
175+
176+
Args:
177+
command_seq: The sequence number of the command.
178+
179+
Returns:
180+
The serialization context for the command, or None if not found.
181+
"""
182+
raise NotImplementedError
183+
170184
def get_thread_id(self) -> Optional[int]:
171185
"""Return the thread identifier that this workflow is running on.
172186
@@ -225,6 +239,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
225239
self._context_free_failure_converter,
226240
)
227241
)
242+
self._payload_codec = det.data_converter.payload_codec
228243

229244
self._extern_functions = det.extern_functions
230245
self._disable_eager_activity_execution = det.disable_eager_activity_execution
@@ -2089,6 +2104,45 @@ def _converters_with_context(
20892104
failure_converter = failure_converter.with_context(context)
20902105
return payload_converter, failure_converter
20912106

2107+
def get_pending_command_serialization_context(
2108+
self, command_seq: int
2109+
) -> Optional[temporalio.converter.SerializationContext]:
2110+
if isinstance(
2111+
self._payload_codec, temporalio.converter.WithSerializationContext
2112+
):
2113+
if command_seq in self._pending_activities:
2114+
handle = self._pending_activities[command_seq]
2115+
return temporalio.converter.ActivitySerializationContext(
2116+
namespace=self._info.namespace,
2117+
workflow_id=self._info.workflow_id,
2118+
workflow_type=self._info.workflow_type,
2119+
activity_type=handle._input.activity,
2120+
activity_task_queue=(
2121+
handle._input.task_queue or self._info.task_queue
2122+
if isinstance(handle._input, StartActivityInput)
2123+
else self._info.task_queue
2124+
),
2125+
is_local=isinstance(handle._input, StartLocalActivityInput),
2126+
)
2127+
2128+
elif command_seq in self._pending_child_workflows:
2129+
handle = self._pending_child_workflows[command_seq]
2130+
return temporalio.converter.WorkflowSerializationContext(
2131+
namespace=self._info.namespace,
2132+
workflow_id=handle._input.id,
2133+
)
2134+
2135+
elif command_seq in self._pending_external_signals:
2136+
_, workflow_id = self._pending_external_signals[command_seq]
2137+
return temporalio.converter.WorkflowSerializationContext(
2138+
namespace=self._info.namespace,
2139+
workflow_id=workflow_id,
2140+
)
2141+
2142+
elif command_seq in self._pending_nexus_operations:
2143+
# We don't set any context for nexus operations
2144+
pass
2145+
20922146
def _instantiate_workflow_object(self) -> Any:
20932147
if not self._workflow_input:
20942148
raise RuntimeError("Expected workflow input. This is a Python SDK bug.")

0 commit comments

Comments
 (0)