Skip to content

Commit 4a744fa

Browse files
committed
Compute context, not payload codec
1 parent fc3ddf7 commit 4a744fa

File tree

5 files changed

+52
-80
lines changed

5 files changed

+52
-80
lines changed

temporalio/worker/_workflow.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,6 @@ async def _handle_activation(
290290
payload_codec = _CommandAwarePayloadCodec(
291291
workflow.instance,
292292
context_free_payload_codec=self._data_converter.payload_codec,
293-
workflow_context_payload_codec=data_converter.payload_codec,
294293
)
295294
await temporalio.bridge.worker.decode_activation(
296295
act,
@@ -367,7 +366,6 @@ async def _handle_activation(
367366
payload_codec = _CommandAwarePayloadCodec(
368367
workflow.instance,
369368
context_free_payload_codec=self._data_converter.payload_codec,
370-
workflow_context_payload_codec=data_converter.payload_codec,
371369
)
372370
try:
373371
await temporalio.bridge.worker.encode_completion(
@@ -734,7 +732,6 @@ class _CommandAwarePayloadCodec(temporalio.converter.PayloadCodec):
734732

735733
instance: WorkflowInstance
736734
context_free_payload_codec: temporalio.converter.PayloadCodec
737-
workflow_context_payload_codec: temporalio.converter.PayloadCodec
738735

739736
async def encode(
740737
self,
@@ -749,11 +746,18 @@ async def decode(
749746
return await self._get_current_command_codec().decode(payloads)
750747

751748
def _get_current_command_codec(self) -> temporalio.converter.PayloadCodec:
752-
return self.instance.get_payload_codec_with_context(
749+
if not isinstance(
753750
self.context_free_payload_codec,
754-
self.workflow_context_payload_codec,
751+
temporalio.converter.WithSerializationContext,
752+
):
753+
return self.context_free_payload_codec
754+
755+
if context := self.instance.get_serialization_context(
755756
temporalio.bridge._visitor.current_command_info.get(),
756-
)
757+
):
758+
return self.context_free_payload_codec.with_context(context)
759+
760+
return self.context_free_payload_codec
757761

758762

759763
class _InterruptDeadlockError(BaseException):

temporalio/worker/_workflow_instance.py

Lines changed: 31 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -170,22 +170,18 @@ def activate(
170170
raise NotImplementedError
171171

172172
@abstractmethod
173-
def get_payload_codec_with_context(
173+
def get_serialization_context(
174174
self,
175-
base_payload_codec: temporalio.converter.PayloadCodec,
176-
workflow_context_payload_codec: temporalio.converter.PayloadCodec,
177175
command_info: Optional[temporalio.bridge._visitor.CommandInfo],
178-
) -> temporalio.converter.PayloadCodec:
179-
"""Return a payload codec with appropriate serialization context.
176+
) -> Optional[temporalio.converter.SerializationContext]:
177+
"""Return appropriate serialization context.
180178
181179
Args:
182-
base_payload_codec: The base payload codec to apply context to.
183-
workflow_context_payload_codec: A payload codec that already has workflow context set.
184180
command_info: Optional information identifying the associated command. If set, the payload
185181
codec will have serialization context set appropriately for that command.
186182
187183
Returns:
188-
The payload codec.
184+
The serialization context, or None if no context should be set.
189185
"""
190186
raise NotImplementedError
191187

@@ -2116,22 +2112,18 @@ def _converters_with_context(
21162112
failure_converter = failure_converter.with_context(context)
21172113
return payload_converter, failure_converter
21182114

2119-
def get_payload_codec_with_context(
2115+
def get_serialization_context(
21202116
self,
2121-
base_payload_codec: temporalio.converter.PayloadCodec,
2122-
workflow_context_payload_codec: temporalio.converter.PayloadCodec,
21232117
command_info: Optional[temporalio.bridge._visitor.CommandInfo],
2124-
) -> temporalio.converter.PayloadCodec:
2125-
if not isinstance(
2126-
base_payload_codec,
2127-
temporalio.converter.WithSerializationContext,
2128-
):
2129-
return base_payload_codec
2130-
2118+
) -> Optional[temporalio.converter.SerializationContext]:
2119+
workflow_context = temporalio.converter.WorkflowSerializationContext(
2120+
namespace=self._info.namespace,
2121+
workflow_id=self._info.workflow_id,
2122+
)
21312123
if command_info is None:
21322124
# Use payload codec with workflow context by default (i.e. for payloads not associated
21332125
# with a pending command)
2134-
return workflow_context_payload_codec
2126+
return workflow_context
21352127

21362128
if (
21372129
command_info.command_type
@@ -2140,22 +2132,18 @@ def get_payload_codec_with_context(
21402132
):
21412133
# Use the activity's context
21422134
activity_handle = self._pending_activities[command_info.command_seq]
2143-
return base_payload_codec.with_context(
2144-
temporalio.converter.ActivitySerializationContext(
2145-
namespace=self._info.namespace,
2146-
workflow_id=self._info.workflow_id,
2147-
workflow_type=self._info.workflow_type,
2148-
activity_type=activity_handle._input.activity,
2149-
activity_task_queue=(
2150-
activity_handle._input.task_queue
2151-
if isinstance(activity_handle._input, StartActivityInput)
2152-
and activity_handle._input.task_queue
2153-
else self._info.task_queue
2154-
),
2155-
is_local=isinstance(
2156-
activity_handle._input, StartLocalActivityInput
2157-
),
2158-
)
2135+
return temporalio.converter.ActivitySerializationContext(
2136+
namespace=self._info.namespace,
2137+
workflow_id=self._info.workflow_id,
2138+
workflow_type=self._info.workflow_type,
2139+
activity_type=activity_handle._input.activity,
2140+
activity_task_queue=(
2141+
activity_handle._input.task_queue
2142+
if isinstance(activity_handle._input, StartActivityInput)
2143+
and activity_handle._input.task_queue
2144+
else self._info.task_queue
2145+
),
2146+
is_local=isinstance(activity_handle._input, StartLocalActivityInput),
21592147
)
21602148

21612149
elif (
@@ -2165,11 +2153,9 @@ def get_payload_codec_with_context(
21652153
):
21662154
# Use the child workflow's context
21672155
child_wf_handle = self._pending_child_workflows[command_info.command_seq]
2168-
return base_payload_codec.with_context(
2169-
temporalio.converter.WorkflowSerializationContext(
2170-
namespace=self._info.namespace,
2171-
workflow_id=child_wf_handle._input.id,
2172-
)
2156+
return temporalio.converter.WorkflowSerializationContext(
2157+
namespace=self._info.namespace,
2158+
workflow_id=child_wf_handle._input.id,
21732159
)
21742160

21752161
elif (
@@ -2181,11 +2167,9 @@ def get_payload_codec_with_context(
21812167
_, target_workflow_id = self._pending_external_signals[
21822168
command_info.command_seq
21832169
]
2184-
return base_payload_codec.with_context(
2185-
temporalio.converter.WorkflowSerializationContext(
2186-
namespace=self._info.namespace,
2187-
workflow_id=target_workflow_id,
2188-
)
2170+
return temporalio.converter.WorkflowSerializationContext(
2171+
namespace=self._info.namespace,
2172+
workflow_id=target_workflow_id,
21892173
)
21902174

21912175
elif (
@@ -2196,11 +2180,11 @@ def get_payload_codec_with_context(
21962180
# Use empty context for nexus operations: users will never want to encrypt using a
21972181
# key derived from caller workflow context because the caller workflow context is
21982182
# not available on the handler side for decryption.
2199-
return base_payload_codec
2183+
return None
22002184

22012185
else:
22022186
# Use payload codec with workflow context for all other payloads
2203-
return workflow_context_payload_codec
2187+
return workflow_context
22042188

22052189
def _instantiate_workflow_object(self) -> Any:
22062190
if not self._workflow_input:

temporalio/worker/workflow_sandbox/_in_sandbox.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,9 @@ def activate(
8282
"""Send activation to this instance."""
8383
return self.instance.activate(act)
8484

85-
def get_payload_codec_with_context(
85+
def get_serialization_context(
8686
self,
87-
base_payload_codec: temporalio.converter.PayloadCodec,
88-
workflow_context_payload_codec: temporalio.converter.PayloadCodec,
8987
command_info: Optional[temporalio.bridge._visitor.CommandInfo],
90-
) -> temporalio.converter.PayloadCodec:
91-
"""Get payload codec with context."""
92-
return self.instance.get_payload_codec_with_context(
93-
base_payload_codec,
94-
workflow_context_payload_codec,
95-
command_info,
96-
)
88+
) -> Optional[temporalio.converter.SerializationContext]:
89+
"""Get serialization context."""
90+
return self.instance.get_serialization_context(command_info)

temporalio/worker/workflow_sandbox/_runner.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -187,23 +187,19 @@ def _run_code(self, code: str, **extra_globals: Any) -> None:
187187
def get_thread_id(self) -> Optional[int]:
188188
return self._current_thread_id
189189

190-
def get_payload_codec_with_context(
190+
def get_serialization_context(
191191
self,
192-
base_payload_codec: temporalio.converter.PayloadCodec,
193-
workflow_context_payload_codec: temporalio.converter.PayloadCodec,
194192
command_info: Optional[temporalio.bridge._visitor.CommandInfo],
195-
) -> temporalio.converter.PayloadCodec:
193+
) -> Optional[temporalio.converter.SerializationContext]:
196194
# Forward call to the sandboxed instance
197195
self.importer.restriction_context.is_runtime = True
198196
try:
199197
self._run_code(
200198
"with __temporal_importer.applied():\n"
201-
" __temporal_codec = __temporal_in_sandbox.get_payload_codec_with_context(__temporal_base_payload_codec, __temporal_workflow_context_payload_codec, __temporal_command_info)\n",
199+
" __temporal_context = __temporal_in_sandbox.get_serialization_context(__temporal_command_info)\n",
202200
__temporal_importer=self.importer,
203-
__temporal_base_payload_codec=base_payload_codec,
204-
__temporal_workflow_context_payload_codec=workflow_context_payload_codec,
205201
__temporal_command_info=command_info,
206202
)
207-
return self.globals_and_locals.pop("__temporal_codec", None) # type: ignore
203+
return self.globals_and_locals.pop("__temporal_context", None) # type: ignore
208204
finally:
209205
self.importer.restriction_context.is_runtime = False

tests/worker/test_workflow.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,17 +1611,11 @@ def activate(self, act: WorkflowActivation) -> WorkflowActivationCompletion:
16111611
self._runner._pairs.append((act, comp))
16121612
return comp
16131613

1614-
def get_payload_codec_with_context(
1614+
def get_serialization_context(
16151615
self,
1616-
base_payload_codec: temporalio.converter.PayloadCodec,
1617-
workflow_context_payload_codec: temporalio.converter.PayloadCodec,
16181616
command_info: Optional[temporalio.bridge._visitor.CommandInfo],
1619-
) -> temporalio.converter.PayloadCodec:
1620-
return self._unsandboxed.get_payload_codec_with_context(
1621-
base_payload_codec,
1622-
workflow_context_payload_codec,
1623-
command_info,
1624-
)
1617+
) -> Optional[temporalio.converter.SerializationContext]:
1618+
return self._unsandboxed.get_serialization_context(command_info)
16251619

16261620

16271621
async def test_workflow_with_custom_runner(client: Client):

0 commit comments

Comments
 (0)