Skip to content

Commit fcd40ef

Browse files
committed
Command-aware serialization context during payload visitor traversal
- Change visitor API - Add command seq IDs to async context during payload visitor traversal - Store data_converter on WorkflowInstanceDetails instead of converter classes - Command-aware codec for payload visitor
1 parent 56dc0ab commit fcd40ef

File tree

11 files changed

+738
-128
lines changed

11 files changed

+738
-128
lines changed

scripts/gen_payload_visitor.py

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import subprocess
22
import sys
33
from pathlib import Path
4-
from typing import Optional, Tuple
4+
from typing import Optional
55

66
from google.protobuf.descriptor import Descriptor, FieldDescriptor
77

@@ -62,6 +62,20 @@ def emit_singular(
6262
await self._visit_{child_method}(fs, {access_expr})"""
6363

6464

65+
def emit_singular_with_seq(
66+
field_name: str, access_expr: str, child_method: str, presence_word: str
67+
) -> str:
68+
# Helper to emit a singular field visit that sets the seq contextvar, with presence check but
69+
# without headers guard since this is used for commands only.
70+
return f"""\
71+
{presence_word} o.HasField("{field_name}"):
72+
token = current_command_seq.set({access_expr}.seq)
73+
try:
74+
await self._visit_{child_method}(fs, {access_expr})
75+
finally:
76+
current_command_seq.reset(token)"""
77+
78+
6579
class VisitorGenerator:
6680
def generate(self, roots: list[Descriptor]) -> str:
6781
"""
@@ -80,10 +94,16 @@ def generate(self, roots: list[Descriptor]) -> str:
8094
header = """
8195
# This file is generated by gen_payload_visitor.py. Changes should be made there.
8296
import abc
83-
from typing import Any, MutableSequence
97+
import contextvars
98+
from typing import Any, MutableSequence, Optional
8499
85100
from temporalio.api.common.v1.message_pb2 import Payload
86101
102+
# Current workflow command sequence number
103+
current_command_seq: contextvars.ContextVar[Optional[int]] = contextvars.ContextVar(
104+
"current_command_seq", default=None
105+
)
106+
87107
class VisitorFunctions(abc.ABC):
88108
\"\"\"Set of functions which can be called by the visitor.
89109
Allows handling payloads as a sequence.
@@ -253,6 +273,29 @@ def walk(self, desc: Descriptor) -> bool:
253273
)
254274
)
255275

276+
commands_with_seq = {
277+
"start_timer",
278+
"cancel_timer",
279+
"schedule_activity",
280+
"schedule_local_activity",
281+
"request_cancel_activity",
282+
"request_cancel_local_activity",
283+
"start_child_workflow_execution",
284+
"request_cancel_external_workflow_execution",
285+
"signal_external_workflow_execution",
286+
"cancel_signal_workflow",
287+
"schedule_nexus_operation",
288+
"request_cancel_nexus_operation",
289+
}
290+
activation_jobs_with_seq = {
291+
"resolve_activity",
292+
"resolve_child_workflow_execution",
293+
"resolve_child_workflow_execution_start",
294+
"resolve_request_cancel_external_workflow",
295+
"resolve_signal_external_workflow",
296+
"resolve_nexus_operation",
297+
"resolve_nexus_operation_start",
298+
}
256299
# Process oneof fields as if/elif chains
257300
for oneof_idx, fields in oneof_fields.items():
258301
oneof_lines = []
@@ -264,9 +307,25 @@ def walk(self, desc: Descriptor) -> bool:
264307
if child_has_payload:
265308
if_word = "if" if first else "elif"
266309
first = False
267-
line = emit_singular(
268-
field.name, f"o.{field.name}", name_for(child_desc), if_word
269-
)
310+
if (
311+
desc.full_name == "coresdk.workflow_commands.WorkflowCommand"
312+
and field.name in commands_with_seq
313+
):
314+
line = emit_singular_with_seq(
315+
field.name, f"o.{field.name}", name_for(child_desc), if_word
316+
)
317+
elif (
318+
desc.full_name
319+
== "coresdk.workflow_activation.WorkflowActivationJob"
320+
and field.name in activation_jobs_with_seq
321+
):
322+
line = emit_singular_with_seq(
323+
field.name, f"o.{field.name}", name_for(child_desc), if_word
324+
)
325+
else:
326+
line = emit_singular(
327+
field.name, f"o.{field.name}", name_for(child_desc), if_word
328+
)
270329
oneof_lines.append(line)
271330
if oneof_lines:
272331
lines.extend(oneof_lines)

temporalio/bridge/_visitor.py

Lines changed: 95 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
# This file is generated by gen_payload_visitor.py. Changes should be made there.
22
import abc
3-
from typing import Any, MutableSequence
3+
import contextvars
4+
from typing import Any, MutableSequence, Optional
45

56
from temporalio.api.common.v1.message_pb2 import Payload
67

8+
# Current workflow command sequence number
9+
current_command_seq: contextvars.ContextVar[Optional[int]] = contextvars.ContextVar(
10+
"current_command_seq", default=None
11+
)
12+
713

814
class VisitorFunctions(abc.ABC):
915
"""Set of functions which can be called by the visitor.
@@ -247,35 +253,69 @@ async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, fs, o):
247253
fs, o.signal_workflow
248254
)
249255
elif o.HasField("resolve_activity"):
250-
await self._visit_coresdk_workflow_activation_ResolveActivity(
251-
fs, o.resolve_activity
252-
)
256+
token = current_command_seq.set(o.resolve_activity.seq)
257+
try:
258+
await self._visit_coresdk_workflow_activation_ResolveActivity(
259+
fs, o.resolve_activity
260+
)
261+
finally:
262+
current_command_seq.reset(token)
253263
elif o.HasField("resolve_child_workflow_execution_start"):
254-
await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart(
255-
fs, o.resolve_child_workflow_execution_start
264+
token = current_command_seq.set(
265+
o.resolve_child_workflow_execution_start.seq
256266
)
267+
try:
268+
await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart(
269+
fs, o.resolve_child_workflow_execution_start
270+
)
271+
finally:
272+
current_command_seq.reset(token)
257273
elif o.HasField("resolve_child_workflow_execution"):
258-
await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecution(
259-
fs, o.resolve_child_workflow_execution
260-
)
274+
token = current_command_seq.set(o.resolve_child_workflow_execution.seq)
275+
try:
276+
await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecution(
277+
fs, o.resolve_child_workflow_execution
278+
)
279+
finally:
280+
current_command_seq.reset(token)
261281
elif o.HasField("resolve_signal_external_workflow"):
262-
await self._visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow(
263-
fs, o.resolve_signal_external_workflow
264-
)
282+
token = current_command_seq.set(o.resolve_signal_external_workflow.seq)
283+
try:
284+
await self._visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow(
285+
fs, o.resolve_signal_external_workflow
286+
)
287+
finally:
288+
current_command_seq.reset(token)
265289
elif o.HasField("resolve_request_cancel_external_workflow"):
266-
await self._visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow(
267-
fs, o.resolve_request_cancel_external_workflow
290+
token = current_command_seq.set(
291+
o.resolve_request_cancel_external_workflow.seq
268292
)
293+
try:
294+
await self._visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow(
295+
fs, o.resolve_request_cancel_external_workflow
296+
)
297+
finally:
298+
current_command_seq.reset(token)
269299
elif o.HasField("do_update"):
270300
await self._visit_coresdk_workflow_activation_DoUpdate(fs, o.do_update)
271301
elif o.HasField("resolve_nexus_operation_start"):
272-
await self._visit_coresdk_workflow_activation_ResolveNexusOperationStart(
273-
fs, o.resolve_nexus_operation_start
274-
)
302+
token = current_command_seq.set(o.resolve_nexus_operation_start.seq)
303+
try:
304+
await (
305+
self._visit_coresdk_workflow_activation_ResolveNexusOperationStart(
306+
fs, o.resolve_nexus_operation_start
307+
)
308+
)
309+
finally:
310+
current_command_seq.reset(token)
275311
elif o.HasField("resolve_nexus_operation"):
276-
await self._visit_coresdk_workflow_activation_ResolveNexusOperation(
277-
fs, o.resolve_nexus_operation
278-
)
312+
token = current_command_seq.set(o.resolve_nexus_operation.seq)
313+
try:
314+
await self._visit_coresdk_workflow_activation_ResolveNexusOperation(
315+
fs, o.resolve_nexus_operation
316+
)
317+
finally:
318+
current_command_seq.reset(token)
279319

280320
async def _visit_coresdk_workflow_activation_WorkflowActivation(self, fs, o):
281321
for v in o.jobs:
@@ -371,9 +411,13 @@ async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o):
371411
if o.HasField("user_metadata"):
372412
await self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata)
373413
if o.HasField("schedule_activity"):
374-
await self._visit_coresdk_workflow_commands_ScheduleActivity(
375-
fs, o.schedule_activity
376-
)
414+
token = current_command_seq.set(o.schedule_activity.seq)
415+
try:
416+
await self._visit_coresdk_workflow_commands_ScheduleActivity(
417+
fs, o.schedule_activity
418+
)
419+
finally:
420+
current_command_seq.reset(token)
377421
elif o.HasField("respond_to_query"):
378422
await self._visit_coresdk_workflow_commands_QueryResult(
379423
fs, o.respond_to_query
@@ -391,17 +435,29 @@ async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o):
391435
fs, o.continue_as_new_workflow_execution
392436
)
393437
elif o.HasField("start_child_workflow_execution"):
394-
await self._visit_coresdk_workflow_commands_StartChildWorkflowExecution(
395-
fs, o.start_child_workflow_execution
396-
)
438+
token = current_command_seq.set(o.start_child_workflow_execution.seq)
439+
try:
440+
await self._visit_coresdk_workflow_commands_StartChildWorkflowExecution(
441+
fs, o.start_child_workflow_execution
442+
)
443+
finally:
444+
current_command_seq.reset(token)
397445
elif o.HasField("signal_external_workflow_execution"):
398-
await self._visit_coresdk_workflow_commands_SignalExternalWorkflowExecution(
399-
fs, o.signal_external_workflow_execution
400-
)
446+
token = current_command_seq.set(o.signal_external_workflow_execution.seq)
447+
try:
448+
await self._visit_coresdk_workflow_commands_SignalExternalWorkflowExecution(
449+
fs, o.signal_external_workflow_execution
450+
)
451+
finally:
452+
current_command_seq.reset(token)
401453
elif o.HasField("schedule_local_activity"):
402-
await self._visit_coresdk_workflow_commands_ScheduleLocalActivity(
403-
fs, o.schedule_local_activity
404-
)
454+
token = current_command_seq.set(o.schedule_local_activity.seq)
455+
try:
456+
await self._visit_coresdk_workflow_commands_ScheduleLocalActivity(
457+
fs, o.schedule_local_activity
458+
)
459+
finally:
460+
current_command_seq.reset(token)
405461
elif o.HasField("upsert_workflow_search_attributes"):
406462
await self._visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes(
407463
fs, o.upsert_workflow_search_attributes
@@ -415,9 +471,13 @@ async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o):
415471
fs, o.update_response
416472
)
417473
elif o.HasField("schedule_nexus_operation"):
418-
await self._visit_coresdk_workflow_commands_ScheduleNexusOperation(
419-
fs, o.schedule_nexus_operation
420-
)
474+
token = current_command_seq.set(o.schedule_nexus_operation.seq)
475+
try:
476+
await self._visit_coresdk_workflow_commands_ScheduleNexusOperation(
477+
fs, o.schedule_nexus_operation
478+
)
479+
finally:
480+
current_command_seq.reset(token)
421481

422482
async def _visit_coresdk_workflow_completion_Success(self, fs, o):
423483
for v in o.commands:

temporalio/bridge/worker.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77

88
from dataclasses import dataclass
99
from typing import (
10-
TYPE_CHECKING,
1110
Awaitable,
1211
Callable,
1312
List,
14-
Mapping,
1513
MutableSequence,
1614
Optional,
1715
Sequence,
@@ -20,7 +18,6 @@
2018
Union,
2119
)
2220

23-
import google.protobuf.internal.containers
2421
from typing_extensions import TypeAlias
2522

2623
import temporalio.api.common.v1
@@ -35,7 +32,7 @@
3532
import temporalio.bridge.temporal_sdk_bridge
3633
import temporalio.converter
3734
import temporalio.exceptions
38-
from temporalio.api.common.v1.message_pb2 import Payload, Payloads
35+
from temporalio.api.common.v1.message_pb2 import Payload
3936
from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions
4037
from temporalio.bridge.temporal_sdk_bridge import (
4138
CustomSlotSupplier as BridgeCustomSlotSupplier,
@@ -299,22 +296,22 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
299296

300297

301298
async def decode_activation(
302-
act: temporalio.bridge.proto.workflow_activation.WorkflowActivation,
303-
codec: temporalio.converter.PayloadCodec,
299+
activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation,
300+
decode: Callable[[Sequence[Payload]], Awaitable[List[Payload]]],
304301
decode_headers: bool,
305302
) -> None:
306-
"""Decode the given activation with the codec."""
303+
"""Decode all payloads in the activation."""
307304
await PayloadVisitor(
308305
skip_search_attributes=True, skip_headers=not decode_headers
309-
).visit(_Visitor(codec.decode), act)
306+
).visit(_Visitor(decode), activation)
310307

311308

312309
async def encode_completion(
313-
comp: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion,
314-
codec: temporalio.converter.PayloadCodec,
310+
completion: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion,
311+
encode: Callable[[Sequence[Payload]], Awaitable[List[Payload]]],
315312
encode_headers: bool,
316313
) -> None:
317-
"""Recursively encode the given completion with the codec."""
314+
"""Encode all payloads in the completion."""
318315
await PayloadVisitor(
319316
skip_search_attributes=True, skip_headers=not encode_headers
320-
).visit(_Visitor(codec.encode), comp)
317+
).visit(_Visitor(encode), completion)

temporalio/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5974,6 +5974,7 @@ async def _populate_start_workflow_execution_request(
59745974
req.workflow_type.name = input.workflow
59755975
req.task_queue.name = input.task_queue
59765976
if input.args:
5977+
# client encode wf input
59775978
req.input.payloads.extend(await data_converter.encode(input.args))
59785979
if input.execution_timeout is not None:
59795980
req.workflow_execution_timeout.FromTimedelta(input.execution_timeout)

0 commit comments

Comments
 (0)