Skip to content

Commit 2407569

Browse files
committed
Reflection based visitor
1 parent 251084d commit 2407569

File tree

4 files changed

+205
-142
lines changed

4 files changed

+205
-142
lines changed

temporalio/bridge/visitor.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from typing import Awaitable, Callable, Any
2+
3+
from collections.abc import Mapping as AbcMapping, Sequence as AbcSequence
4+
5+
from google.protobuf.descriptor import FieldDescriptor
6+
from google.protobuf.message import Message
7+
8+
from temporalio.api.common.v1.message_pb2 import Payload
9+
10+
11+
async def visit_payloads(
12+
f: Callable[[Payload], Awaitable[Payload]], root: Any
13+
) -> None:
14+
print("Visiting object: ", type(root))
15+
if isinstance(root, Payload):
16+
print("Applying to payload: ", root)
17+
root.CopyFrom(await f(root))
18+
print("Applied to payload: ", root)
19+
elif isinstance(root, AbcMapping):
20+
for k, v in root.items():
21+
await visit_payloads(f, k)
22+
await visit_payloads(f, v)
23+
elif isinstance(root, AbcSequence) and not isinstance(
24+
root, (bytes, bytearray, str)
25+
):
26+
for o in root:
27+
await visit_payloads(f, o)
28+
elif isinstance(root, Message):
29+
await visit_message(f, root)
30+
31+
32+
async def visit_message(
33+
f: Callable[[Payload], Awaitable[Payload]], root: Message
34+
) -> None:
35+
print("Visiting Message: ", type(root))
36+
for field in root.DESCRIPTOR.fields:
37+
print("Evaluating Field: ", field.name)
38+
39+
# Repeated fields (including maps which are represented as repeated messages)
40+
if field.label == FieldDescriptor.LABEL_REPEATED:
41+
value = getattr(root, field.name)
42+
if field.message_type is not None and field.message_type.GetOptions().map_entry:
43+
for k, v in value.items():
44+
await visit_payloads(f, k)
45+
await visit_payloads(f, v)
46+
else:
47+
for item in value:
48+
await visit_payloads(f, item)
49+
else:
50+
# Only descend into singular message fields if present
51+
if field.type == FieldDescriptor.TYPE_MESSAGE and root.HasField(field.name):
52+
value = getattr(root, field.name)
53+
await visit_payloads(f, value)

temporalio/bridge/worker.py

Lines changed: 14 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121

2222
import google.protobuf.internal.containers
23+
from google.protobuf.message import Message
2324
from typing_extensions import TypeAlias
2425

2526
import temporalio.api.common.v1
@@ -39,6 +40,9 @@
3940
)
4041
from temporalio.bridge.temporal_sdk_bridge import PollShutdownError # type: ignore
4142

43+
from temporalio.api.common.v1.message_pb2 import Payload
44+
from temporalio.bridge.visitor import visit_payloads, visit_message
45+
4246

4347
@dataclass
4448
class WorkerConfig:
@@ -368,15 +372,9 @@ async def _encode_payloads(
368372
codec: temporalio.converter.PayloadCodec,
369373
) -> None:
370374
"""Encode payloads with the given codec."""
371-
return await _apply_to_payloads(payloads, codec.encode)
372-
373-
374-
async def _encode_payload(
375-
payload: temporalio.api.common.v1.Payload,
376-
codec: temporalio.converter.PayloadCodec,
377-
) -> None:
378-
"""Decode a payload with the given codec."""
379-
return await _apply_to_payload(payload, codec.encode)
375+
async def visitor(payload: Payload) -> Payload:
376+
return (await codec.encode([payload]))[0]
377+
return await visit_payloads(visitor, payloads)
380378

381379

382380
async def decode_activation(
@@ -385,144 +383,18 @@ async def decode_activation(
385383
decode_headers: bool,
386384
) -> None:
387385
"""Decode the given activation with the codec."""
388-
for job in act.jobs:
389-
if job.HasField("query_workflow"):
390-
await _decode_payloads(job.query_workflow.arguments, codec)
391-
if decode_headers:
392-
await _decode_headers(job.query_workflow.headers, codec)
393-
elif job.HasField("resolve_activity"):
394-
if job.resolve_activity.result.HasField("cancelled"):
395-
await codec.decode_failure(
396-
job.resolve_activity.result.cancelled.failure
397-
)
398-
elif job.resolve_activity.result.HasField("completed"):
399-
if job.resolve_activity.result.completed.HasField("result"):
400-
await _decode_payload(
401-
job.resolve_activity.result.completed.result, codec
402-
)
403-
elif job.resolve_activity.result.HasField("failed"):
404-
await codec.decode_failure(job.resolve_activity.result.failed.failure)
405-
elif job.HasField("resolve_child_workflow_execution"):
406-
if job.resolve_child_workflow_execution.result.HasField("cancelled"):
407-
await codec.decode_failure(
408-
job.resolve_child_workflow_execution.result.cancelled.failure
409-
)
410-
elif job.resolve_child_workflow_execution.result.HasField(
411-
"completed"
412-
) and job.resolve_child_workflow_execution.result.completed.HasField(
413-
"result"
414-
):
415-
await _decode_payload(
416-
job.resolve_child_workflow_execution.result.completed.result, codec
417-
)
418-
elif job.resolve_child_workflow_execution.result.HasField("failed"):
419-
await codec.decode_failure(
420-
job.resolve_child_workflow_execution.result.failed.failure
421-
)
422-
elif job.HasField("resolve_child_workflow_execution_start"):
423-
if job.resolve_child_workflow_execution_start.HasField("cancelled"):
424-
await codec.decode_failure(
425-
job.resolve_child_workflow_execution_start.cancelled.failure
426-
)
427-
elif job.HasField("resolve_request_cancel_external_workflow"):
428-
if job.resolve_request_cancel_external_workflow.HasField("failure"):
429-
await codec.decode_failure(
430-
job.resolve_request_cancel_external_workflow.failure
431-
)
432-
elif job.HasField("resolve_signal_external_workflow"):
433-
if job.resolve_signal_external_workflow.HasField("failure"):
434-
await codec.decode_failure(job.resolve_signal_external_workflow.failure)
435-
elif job.HasField("signal_workflow"):
436-
await _decode_payloads(job.signal_workflow.input, codec)
437-
if decode_headers:
438-
await _decode_headers(job.signal_workflow.headers, codec)
439-
elif job.HasField("initialize_workflow"):
440-
await _decode_payloads(job.initialize_workflow.arguments, codec)
441-
if decode_headers:
442-
await _decode_headers(job.initialize_workflow.headers, codec)
443-
if job.initialize_workflow.HasField("continued_failure"):
444-
await codec.decode_failure(job.initialize_workflow.continued_failure)
445-
for val in job.initialize_workflow.memo.fields.values():
446-
# This uses API payload not bridge payload
447-
new_payload = (await codec.decode([val]))[0]
448-
# Make a shallow copy, in case new_payload.metadata and val.metadata are
449-
# references to the same memory, e.g. decode() returns its input unchanged.
450-
new_metadata = dict(new_payload.metadata)
451-
val.metadata.clear()
452-
val.metadata.update(new_metadata)
453-
val.data = new_payload.data
454-
elif job.HasField("do_update"):
455-
await _decode_payloads(job.do_update.input, codec)
456-
if decode_headers:
457-
await _decode_headers(job.do_update.headers, codec)
386+
async def visitor(payload: Payload) -> Payload:
387+
return (await codec.decode([payload]))[0]
458388

389+
await visit_message(visitor, act)
459390

460391
async def encode_completion(
461392
comp: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion,
462393
codec: temporalio.converter.PayloadCodec,
463394
encode_headers: bool,
464395
) -> None:
465396
"""Recursively encode the given completion with the codec."""
466-
if comp.HasField("failed"):
467-
await codec.encode_failure(comp.failed.failure)
468-
elif comp.HasField("successful"):
469-
for command in comp.successful.commands:
470-
if command.HasField("complete_workflow_execution"):
471-
if command.complete_workflow_execution.HasField("result"):
472-
await _encode_payload(
473-
command.complete_workflow_execution.result, codec
474-
)
475-
elif command.HasField("continue_as_new_workflow_execution"):
476-
await _encode_payloads(
477-
command.continue_as_new_workflow_execution.arguments, codec
478-
)
479-
if encode_headers:
480-
await _encode_headers(
481-
command.continue_as_new_workflow_execution.headers, codec
482-
)
483-
for val in command.continue_as_new_workflow_execution.memo.values():
484-
await _encode_payload(val, codec)
485-
elif command.HasField("fail_workflow_execution"):
486-
await codec.encode_failure(command.fail_workflow_execution.failure)
487-
elif command.HasField("respond_to_query"):
488-
if command.respond_to_query.HasField("failed"):
489-
await codec.encode_failure(command.respond_to_query.failed)
490-
elif command.respond_to_query.HasField(
491-
"succeeded"
492-
) and command.respond_to_query.succeeded.HasField("response"):
493-
await _encode_payload(
494-
command.respond_to_query.succeeded.response, codec
495-
)
496-
elif command.HasField("schedule_activity"):
497-
await _encode_payloads(command.schedule_activity.arguments, codec)
498-
if encode_headers:
499-
await _encode_headers(command.schedule_activity.headers, codec)
500-
elif command.HasField("schedule_local_activity"):
501-
await _encode_payloads(command.schedule_local_activity.arguments, codec)
502-
if encode_headers:
503-
await _encode_headers(
504-
command.schedule_local_activity.headers, codec
505-
)
506-
elif command.HasField("signal_external_workflow_execution"):
507-
await _encode_payloads(
508-
command.signal_external_workflow_execution.args, codec
509-
)
510-
if encode_headers:
511-
await _encode_headers(
512-
command.signal_external_workflow_execution.headers, codec
513-
)
514-
elif command.HasField("start_child_workflow_execution"):
515-
await _encode_payloads(
516-
command.start_child_workflow_execution.input, codec
517-
)
518-
if encode_headers:
519-
await _encode_headers(
520-
command.start_child_workflow_execution.headers, codec
521-
)
522-
for val in command.start_child_workflow_execution.memo.values():
523-
await _encode_payload(val, codec)
524-
elif command.HasField("update_response"):
525-
if command.update_response.HasField("completed"):
526-
await _encode_payload(command.update_response.completed, codec)
527-
elif command.update_response.HasField("rejected"):
528-
await codec.encode_failure(command.update_response.rejected)
397+
async def visitor(payload: Payload) -> Payload:
398+
return (await codec.encode([payload]))[0]
399+
400+
await visit_message(visitor, comp)

tests/worker/test_visitor.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
from google.protobuf.duration_pb2 import Duration
2+
3+
from temporalio.api.sdk.v1.user_metadata_pb2 import UserMetadata
4+
from temporalio.bridge.proto.workflow_commands.workflow_commands_pb2 import (
5+
WorkflowCommand,
6+
ScheduleActivity,
7+
ScheduleLocalActivity,
8+
ContinueAsNewWorkflowExecution,
9+
StartChildWorkflowExecution,
10+
SignalExternalWorkflowExecution,
11+
UpdateResponse,
12+
)
13+
from temporalio.bridge.proto.workflow_completion.workflow_completion_pb2 import (
14+
Success,
15+
WorkflowActivationCompletion,
16+
)
17+
from temporalio.bridge.visitor import visit_message
18+
from temporalio.api.common.v1.message_pb2 import Payload, Priority
19+
20+
21+
async def test_visit_payloads_mutates_all_payloads_in_message():
22+
comp = WorkflowActivationCompletion(
23+
run_id="1",
24+
successful=Success(
25+
commands=[
26+
WorkflowCommand(
27+
schedule_activity=ScheduleActivity(
28+
seq=1,
29+
activity_id="1",
30+
activity_type="",
31+
task_queue="",
32+
headers={"foo": Payload(data=b"bar")},
33+
arguments=[Payload(data=b"baz")],
34+
schedule_to_close_timeout=Duration(seconds=5),
35+
priority=Priority(),
36+
),
37+
user_metadata=UserMetadata(
38+
summary=Payload(data=b"Summary")
39+
),
40+
)
41+
],
42+
),
43+
)
44+
45+
async def visitor(payload: Payload) -> Payload:
46+
# Mark visited by prefixing data
47+
new_payload = Payload()
48+
new_payload.metadata.update(payload.metadata)
49+
new_payload.data = b"visited:" + payload.data
50+
return new_payload
51+
52+
await visit_message(visitor, comp)
53+
54+
cmd = comp.successful.commands[0]
55+
sa = cmd.schedule_activity
56+
assert sa.headers["foo"].data == b"visited:bar"
57+
assert len(sa.arguments) == 1 and sa.arguments[0].data == b"visited:baz"
58+
59+
assert cmd.user_metadata.summary.data == b"visited:Summary"
60+
61+
62+
async def test_visit_payloads_on_other_commands():
63+
comp = WorkflowActivationCompletion(
64+
run_id="2",
65+
successful=Success(
66+
commands=[
67+
# Continue as new
68+
WorkflowCommand(
69+
continue_as_new_workflow_execution=ContinueAsNewWorkflowExecution(
70+
arguments=[Payload(data=b"a1")],
71+
headers={"h1": Payload(data=b"a2")},
72+
memo={"m1": Payload(data=b"a3")},
73+
)
74+
),
75+
# Start child
76+
WorkflowCommand(
77+
start_child_workflow_execution=StartChildWorkflowExecution(
78+
input=[Payload(data=b"b1")],
79+
headers={"h2": Payload(data=b"b2")},
80+
memo={"m2": Payload(data=b"b3")},
81+
)
82+
),
83+
# Signal external
84+
WorkflowCommand(
85+
signal_external_workflow_execution=SignalExternalWorkflowExecution(
86+
args=[Payload(data=b"c1")],
87+
headers={"h3": Payload(data=b"c2")},
88+
)
89+
),
90+
# Schedule local activity
91+
WorkflowCommand(
92+
schedule_local_activity=ScheduleLocalActivity(
93+
arguments=[Payload(data=b"d1")],
94+
headers={"h4": Payload(data=b"d2")},
95+
)
96+
),
97+
# Update response completed
98+
WorkflowCommand(
99+
update_response=UpdateResponse(
100+
completed=Payload(data=b"e1"),
101+
)
102+
),
103+
]
104+
),
105+
)
106+
107+
async def visitor(payload: Payload) -> Payload:
108+
new_payload = Payload()
109+
new_payload.metadata.update(payload.metadata)
110+
new_payload.data = b"visited:" + payload.data
111+
return new_payload
112+
113+
await visit_message(visitor, comp)
114+
115+
cmds = comp.successful.commands
116+
can = cmds[0].continue_as_new_workflow_execution
117+
assert can.arguments[0].data == b"visited:a1"
118+
assert can.headers["h1"].data == b"visited:a2"
119+
assert can.memo["m1"].data == b"visited:a3"
120+
121+
sc = cmds[1].start_child_workflow_execution
122+
assert sc.input[0].data == b"visited:b1"
123+
assert sc.headers["h2"].data == b"visited:b2"
124+
assert sc.memo["m2"].data == b"visited:b3"
125+
126+
se = cmds[2].signal_external_workflow_execution
127+
assert se.args[0].data == b"visited:c1"
128+
assert se.headers["h3"].data == b"visited:c2"
129+
130+
sla = cmds[3].schedule_local_activity
131+
assert sla.arguments[0].data == b"visited:d1"
132+
assert sla.headers["h4"].data == b"visited:d2"
133+
134+
ur = cmds[4].update_response
135+
assert ur.completed.data == b"visited:e1"

0 commit comments

Comments
 (0)