Skip to content

Commit 3bb9950

Browse files
committed
Use existing workflow context payload codec
1 parent 3103cd0 commit 3bb9950

File tree

5 files changed

+66
-45
lines changed

5 files changed

+66
-45
lines changed

temporalio/worker/_workflow.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ async def _handle_activation(
290290
payload_codec = _CommandAwarePayloadCodec(
291291
workflow.instance,
292292
context_free_payload_codec=self._data_converter.payload_codec,
293+
workflow_context_payload_codec=data_converter.payload_codec,
293294
)
294295
await temporalio.bridge.worker.decode_activation(
295296
act,
@@ -362,9 +363,11 @@ async def _handle_activation(
362363

363364
# Encode completion
364365
if self._data_converter.payload_codec and workflow:
366+
assert data_converter.payload_codec
365367
payload_codec = _CommandAwarePayloadCodec(
366368
workflow.instance,
367369
context_free_payload_codec=self._data_converter.payload_codec,
370+
workflow_context_payload_codec=data_converter.payload_codec,
368371
)
369372
try:
370373
await temporalio.bridge.worker.encode_completion(
@@ -731,6 +734,7 @@ class _CommandAwarePayloadCodec(temporalio.converter.PayloadCodec):
731734

732735
instance: WorkflowInstance
733736
context_free_payload_codec: temporalio.converter.PayloadCodec
737+
workflow_context_payload_codec: temporalio.converter.PayloadCodec
734738

735739
async def encode(
736740
self,
@@ -747,6 +751,7 @@ async def decode(
747751
def _get_current_command_codec(self) -> temporalio.converter.PayloadCodec:
748752
return self.instance.get_payload_codec_with_context(
749753
self.context_free_payload_codec,
754+
self.workflow_context_payload_codec,
750755
temporalio.bridge._visitor.current_command_seq.get(),
751756
)
752757

temporalio/worker/_workflow_instance.py

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,17 @@ def activate(
171171
@abstractmethod
172172
def get_payload_codec_with_context(
173173
self,
174-
payload_codec: temporalio.converter.PayloadCodec,
174+
base_payload_codec: temporalio.converter.PayloadCodec,
175+
workflow_context_payload_codec: temporalio.converter.PayloadCodec,
175176
command_seq: Optional[int],
176177
) -> temporalio.converter.PayloadCodec:
177178
"""Return a payload codec with appropriate serialization context.
178179
179180
Args:
181+
base_payload_codec: The base payload codec to apply context to.
182+
workflow_context_payload_codec: A payload codec that already has workflow context set.
180183
command_seq: Optional sequence number of the associated command. If set, the payload
181-
codec will have serialization context set appropriately for that command.
184+
codec will have serialization context set appropriately for that command.
182185
183186
Returns:
184187
The payload codec.
@@ -2103,68 +2106,71 @@ def _converters_with_context(
21032106

21042107
def get_payload_codec_with_context(
21052108
self,
2106-
payload_codec: temporalio.converter.PayloadCodec,
2109+
base_payload_codec: temporalio.converter.PayloadCodec,
2110+
workflow_context_payload_codec: temporalio.converter.PayloadCodec,
21072111
command_seq: Optional[int],
21082112
) -> temporalio.converter.PayloadCodec:
21092113
if not isinstance(
2110-
payload_codec,
2114+
base_payload_codec,
21112115
temporalio.converter.WithSerializationContext,
21122116
):
2113-
return payload_codec
2114-
2115-
workflow_context = temporalio.converter.WorkflowSerializationContext(
2116-
namespace=self._info.namespace,
2117-
workflow_id=self._info.workflow_id,
2118-
)
2117+
return base_payload_codec
21192118

21202119
if command_seq is None:
21212120
# Use payload codec with workflow context by default (i.e. for payloads not associated
21222121
# with a pending command)
2123-
return payload_codec.with_context(workflow_context)
2122+
return workflow_context_payload_codec
21242123

21252124
if command_seq in self._pending_activities:
2126-
act_handle = self._pending_activities[command_seq]
2127-
act_context = temporalio.converter.ActivitySerializationContext(
2128-
namespace=self._info.namespace,
2129-
workflow_id=self._info.workflow_id,
2130-
workflow_type=self._info.workflow_type,
2131-
activity_type=act_handle._input.activity,
2132-
activity_task_queue=(
2133-
act_handle._input.task_queue
2134-
if isinstance(act_handle._input, StartActivityInput)
2135-
and act_handle._input.task_queue
2136-
else self._info.task_queue
2137-
),
2138-
is_local=isinstance(act_handle._input, StartLocalActivityInput),
2125+
# Use the activity's context
2126+
activity_handle = self._pending_activities[command_seq]
2127+
return base_payload_codec.with_context(
2128+
temporalio.converter.ActivitySerializationContext(
2129+
namespace=self._info.namespace,
2130+
workflow_id=self._info.workflow_id,
2131+
workflow_type=self._info.workflow_type,
2132+
activity_type=activity_handle._input.activity,
2133+
activity_task_queue=(
2134+
activity_handle._input.task_queue
2135+
if isinstance(activity_handle._input, StartActivityInput)
2136+
and activity_handle._input.task_queue
2137+
else self._info.task_queue
2138+
),
2139+
is_local=isinstance(
2140+
activity_handle._input, StartLocalActivityInput
2141+
),
2142+
)
21392143
)
2140-
return payload_codec.with_context(act_context)
21412144

21422145
elif command_seq in self._pending_child_workflows:
2143-
cwf_handle = self._pending_child_workflows[command_seq]
2144-
wf_context = temporalio.converter.WorkflowSerializationContext(
2145-
namespace=self._info.namespace,
2146-
workflow_id=cwf_handle._input.id,
2146+
# Use the child workflow's context
2147+
child_wf_handle = self._pending_child_workflows[command_seq]
2148+
return base_payload_codec.with_context(
2149+
temporalio.converter.WorkflowSerializationContext(
2150+
namespace=self._info.namespace,
2151+
workflow_id=child_wf_handle._input.id,
2152+
)
21472153
)
2148-
return payload_codec.with_context(wf_context)
21492154

21502155
elif command_seq in self._pending_external_signals:
2151-
# Use the target workflow's context for external signals
2152-
_, workflow_id = self._pending_external_signals[command_seq]
2153-
wf_context = temporalio.converter.WorkflowSerializationContext(
2154-
namespace=self._info.namespace,
2155-
workflow_id=workflow_id,
2156+
# Use the target workflow's context
2157+
_, target_workflow_id = self._pending_external_signals[command_seq]
2158+
return base_payload_codec.with_context(
2159+
temporalio.converter.WorkflowSerializationContext(
2160+
namespace=self._info.namespace,
2161+
workflow_id=target_workflow_id,
2162+
)
21562163
)
2157-
return payload_codec.with_context(wf_context)
21582164

21592165
elif command_seq in self._pending_nexus_operations:
21602166
# Use empty context for nexus operations: users will never want to encrypt using a
21612167
# key derived from caller workflow context because the caller workflow context is
21622168
# not available on the handler side for decryption.
2163-
return payload_codec
2169+
return base_payload_codec
21642170

21652171
else:
21662172
# Use payload codec with workflow context for all other payloads
2167-
return payload_codec.with_context(workflow_context)
2173+
return workflow_context_payload_codec
21682174

21692175
def _instantiate_workflow_object(self) -> Any:
21702176
if not self._workflow_input:

temporalio/worker/workflow_sandbox/_in_sandbox.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,13 @@ def activate(
8383

8484
def get_payload_codec_with_context(
8585
self,
86-
payload_codec: temporalio.converter.PayloadCodec,
86+
base_payload_codec: temporalio.converter.PayloadCodec,
87+
workflow_context_payload_codec: temporalio.converter.PayloadCodec,
8788
command_seq: Optional[int],
8889
) -> temporalio.converter.PayloadCodec:
8990
"""Get payload codec with context."""
90-
return self.instance.get_payload_codec_with_context(payload_codec, command_seq)
91+
return self.instance.get_payload_codec_with_context(
92+
base_payload_codec,
93+
workflow_context_payload_codec,
94+
command_seq,
95+
)

temporalio/worker/workflow_sandbox/_runner.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -188,17 +188,19 @@ def get_thread_id(self) -> Optional[int]:
188188

189189
def get_payload_codec_with_context(
190190
self,
191-
payload_codec: temporalio.converter.PayloadCodec,
191+
base_payload_codec: temporalio.converter.PayloadCodec,
192+
workflow_context_payload_codec: temporalio.converter.PayloadCodec,
192193
command_seq: Optional[int],
193194
) -> temporalio.converter.PayloadCodec:
194195
# Forward call to the sandboxed instance
195196
self.importer.restriction_context.is_runtime = True
196197
try:
197198
self._run_code(
198199
"with __temporal_importer.applied():\n"
199-
" __temporal_codec = __temporal_in_sandbox.get_payload_codec_with_context(__temporal_payload_codec, __temporal_command_seq)\n",
200+
" __temporal_codec = __temporal_in_sandbox.get_payload_codec_with_context(__temporal_base_payload_codec, __temporal_workflow_context_payload_codec, __temporal_command_seq)\n",
200201
__temporal_importer=self.importer,
201-
__temporal_payload_codec=payload_codec,
202+
__temporal_base_payload_codec=base_payload_codec,
203+
__temporal_workflow_context_payload_codec=workflow_context_payload_codec,
202204
__temporal_command_seq=command_seq,
203205
)
204206
return self.globals_and_locals.pop("__temporal_codec", None) # type: ignore

tests/worker/test_workflow.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,11 +1612,14 @@ def activate(self, act: WorkflowActivation) -> WorkflowActivationCompletion:
16121612

16131613
def get_payload_codec_with_context(
16141614
self,
1615-
payload_codec: temporalio.converter.PayloadCodec,
1615+
base_payload_codec: temporalio.converter.PayloadCodec,
1616+
workflow_context_payload_codec: temporalio.converter.PayloadCodec,
16161617
command_seq: Optional[int],
16171618
) -> temporalio.converter.PayloadCodec:
16181619
return self._unsandboxed.get_payload_codec_with_context(
1619-
payload_codec, command_seq
1620+
base_payload_codec,
1621+
workflow_context_payload_codec,
1622+
command_seq,
16201623
)
16211624

16221625

0 commit comments

Comments
 (0)