From 240756977c7cb5fca701afa34cc15d4f1f95119a Mon Sep 17 00:00:00 2001
From: Tim Conley
Date: Tue, 2 Sep 2025 12:32:01 -0700
Subject: [PATCH 01/14] Reflection based visitor

---
 temporalio/bridge/visitor.py  |  53 ++++++++++++
 temporalio/bridge/worker.py   | 156 +++-------------------------------
 tests/worker/test_visitor.py  | 135 +++++++++++++++++++++++++++++
 tests/worker/test_workflow.py |   3 +
 4 files changed, 205 insertions(+), 142 deletions(-)
 create mode 100644 temporalio/bridge/visitor.py
 create mode 100644 tests/worker/test_visitor.py

diff --git a/temporalio/bridge/visitor.py b/temporalio/bridge/visitor.py
new file mode 100644
index 000000000..aec1751a4
--- /dev/null
+++ b/temporalio/bridge/visitor.py
@@ -0,0 +1,53 @@
+from typing import Awaitable, Callable, Any
+
+from collections.abc import Mapping as AbcMapping, Sequence as AbcSequence
+
+from google.protobuf.descriptor import FieldDescriptor
+from google.protobuf.message import Message
+
+from temporalio.api.common.v1.message_pb2 import Payload
+
+
+async def visit_payloads(
+    f: Callable[[Payload], Awaitable[Payload]], root: Any
+) -> None:
+    print("Visiting object: ", type(root))
+    if isinstance(root, Payload):
+        print("Applying to payload: ", root)
+        root.CopyFrom(await f(root))
+        print("Applied to payload: ", root)
+    elif isinstance(root, AbcMapping):
+        for k, v in root.items():
+            await visit_payloads(f, k)
+            await visit_payloads(f, v)
+    elif isinstance(root, AbcSequence) and not isinstance(
+        root, (bytes, bytearray, str)
+    ):
+        for o in root:
+            await visit_payloads(f, o)
+    elif isinstance(root, Message):
+        await visit_message(f, root)
+
+
+async def visit_message(
+    f: Callable[[Payload], Awaitable[Payload]], root: Message
+) -> None:
+    print("Visiting Message: ", type(root))
+    for field in root.DESCRIPTOR.fields:
+        print("Evaluating Field: ", field.name)
+
+        # Repeated fields (including maps which are represented as repeated messages)
+        if field.label == FieldDescriptor.LABEL_REPEATED:
+            value = getattr(root, field.name)
+            if field.message_type is not None and field.message_type.GetOptions().map_entry:
+                for k, v in value.items():
+                    await visit_payloads(f, k)
+                    await visit_payloads(f, v)
+            else:
+                for item in value:
+                    await visit_payloads(f, item)
+        else:
+            # Only descend into singular message fields if present
+            if field.type == FieldDescriptor.TYPE_MESSAGE and root.HasField(field.name):
+                value = getattr(root, field.name)
+                await visit_payloads(f, value)
diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py
index e4cb05eee..8c66aaaa5 100644
--- a/temporalio/bridge/worker.py
+++ b/temporalio/bridge/worker.py
@@ -20,6 +20,7 @@
 )
 
 import google.protobuf.internal.containers
+from google.protobuf.message import Message
 from typing_extensions import TypeAlias
 
 import temporalio.api.common.v1
@@ -39,6 +40,9 @@
 )
 from temporalio.bridge.temporal_sdk_bridge import PollShutdownError  # type: ignore
 
+from temporalio.api.common.v1.message_pb2 import Payload
+from temporalio.bridge.visitor import visit_payloads, visit_message
+
 
 @dataclass
 class WorkerConfig:
@@ -368,15 +372,9 @@ async def _encode_payloads(
     codec: temporalio.converter.PayloadCodec,
 ) -> None:
     """Encode payloads with the given codec."""
-    return await _apply_to_payloads(payloads, codec.encode)
-
-
-async def _encode_payload(
-    payload: temporalio.api.common.v1.Payload,
-    codec: temporalio.converter.PayloadCodec,
-) -> None:
-    """Decode a payload with the given codec."""
-    return await _apply_to_payload(payload, codec.encode)
+    async def visitor(payload:
Payload) -> Payload: + return (await codec.encode([payload]))[0] + return await visit_payloads(visitor, payloads) async def decode_activation( @@ -385,77 +383,10 @@ async def decode_activation( decode_headers: bool, ) -> None: """Decode the given activation with the codec.""" - for job in act.jobs: - if job.HasField("query_workflow"): - await _decode_payloads(job.query_workflow.arguments, codec) - if decode_headers: - await _decode_headers(job.query_workflow.headers, codec) - elif job.HasField("resolve_activity"): - if job.resolve_activity.result.HasField("cancelled"): - await codec.decode_failure( - job.resolve_activity.result.cancelled.failure - ) - elif job.resolve_activity.result.HasField("completed"): - if job.resolve_activity.result.completed.HasField("result"): - await _decode_payload( - job.resolve_activity.result.completed.result, codec - ) - elif job.resolve_activity.result.HasField("failed"): - await codec.decode_failure(job.resolve_activity.result.failed.failure) - elif job.HasField("resolve_child_workflow_execution"): - if job.resolve_child_workflow_execution.result.HasField("cancelled"): - await codec.decode_failure( - job.resolve_child_workflow_execution.result.cancelled.failure - ) - elif job.resolve_child_workflow_execution.result.HasField( - "completed" - ) and job.resolve_child_workflow_execution.result.completed.HasField( - "result" - ): - await _decode_payload( - job.resolve_child_workflow_execution.result.completed.result, codec - ) - elif job.resolve_child_workflow_execution.result.HasField("failed"): - await codec.decode_failure( - job.resolve_child_workflow_execution.result.failed.failure - ) - elif job.HasField("resolve_child_workflow_execution_start"): - if job.resolve_child_workflow_execution_start.HasField("cancelled"): - await codec.decode_failure( - job.resolve_child_workflow_execution_start.cancelled.failure - ) - elif job.HasField("resolve_request_cancel_external_workflow"): - if job.resolve_request_cancel_external_workflow.HasField("failure"): - await codec.decode_failure( - job.resolve_request_cancel_external_workflow.failure - ) - elif job.HasField("resolve_signal_external_workflow"): - if job.resolve_signal_external_workflow.HasField("failure"): - await codec.decode_failure(job.resolve_signal_external_workflow.failure) - elif job.HasField("signal_workflow"): - await _decode_payloads(job.signal_workflow.input, codec) - if decode_headers: - await _decode_headers(job.signal_workflow.headers, codec) - elif job.HasField("initialize_workflow"): - await _decode_payloads(job.initialize_workflow.arguments, codec) - if decode_headers: - await _decode_headers(job.initialize_workflow.headers, codec) - if job.initialize_workflow.HasField("continued_failure"): - await codec.decode_failure(job.initialize_workflow.continued_failure) - for val in job.initialize_workflow.memo.fields.values(): - # This uses API payload not bridge payload - new_payload = (await codec.decode([val]))[0] - # Make a shallow copy, in case new_payload.metadata and val.metadata are - # references to the same memory, e.g. decode() returns its input unchanged. 
- new_metadata = dict(new_payload.metadata) - val.metadata.clear() - val.metadata.update(new_metadata) - val.data = new_payload.data - elif job.HasField("do_update"): - await _decode_payloads(job.do_update.input, codec) - if decode_headers: - await _decode_headers(job.do_update.headers, codec) + async def visitor(payload: Payload) -> Payload: + return (await codec.decode([payload]))[0] + await visit_message(visitor, act) async def encode_completion( comp: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion, @@ -463,66 +394,7 @@ async def encode_completion( encode_headers: bool, ) -> None: """Recursively encode the given completion with the codec.""" - if comp.HasField("failed"): - await codec.encode_failure(comp.failed.failure) - elif comp.HasField("successful"): - for command in comp.successful.commands: - if command.HasField("complete_workflow_execution"): - if command.complete_workflow_execution.HasField("result"): - await _encode_payload( - command.complete_workflow_execution.result, codec - ) - elif command.HasField("continue_as_new_workflow_execution"): - await _encode_payloads( - command.continue_as_new_workflow_execution.arguments, codec - ) - if encode_headers: - await _encode_headers( - command.continue_as_new_workflow_execution.headers, codec - ) - for val in command.continue_as_new_workflow_execution.memo.values(): - await _encode_payload(val, codec) - elif command.HasField("fail_workflow_execution"): - await codec.encode_failure(command.fail_workflow_execution.failure) - elif command.HasField("respond_to_query"): - if command.respond_to_query.HasField("failed"): - await codec.encode_failure(command.respond_to_query.failed) - elif command.respond_to_query.HasField( - "succeeded" - ) and command.respond_to_query.succeeded.HasField("response"): - await _encode_payload( - command.respond_to_query.succeeded.response, codec - ) - elif command.HasField("schedule_activity"): - await _encode_payloads(command.schedule_activity.arguments, codec) - if encode_headers: - await _encode_headers(command.schedule_activity.headers, codec) - elif command.HasField("schedule_local_activity"): - await _encode_payloads(command.schedule_local_activity.arguments, codec) - if encode_headers: - await _encode_headers( - command.schedule_local_activity.headers, codec - ) - elif command.HasField("signal_external_workflow_execution"): - await _encode_payloads( - command.signal_external_workflow_execution.args, codec - ) - if encode_headers: - await _encode_headers( - command.signal_external_workflow_execution.headers, codec - ) - elif command.HasField("start_child_workflow_execution"): - await _encode_payloads( - command.start_child_workflow_execution.input, codec - ) - if encode_headers: - await _encode_headers( - command.start_child_workflow_execution.headers, codec - ) - for val in command.start_child_workflow_execution.memo.values(): - await _encode_payload(val, codec) - elif command.HasField("update_response"): - if command.update_response.HasField("completed"): - await _encode_payload(command.update_response.completed, codec) - elif command.update_response.HasField("rejected"): - await codec.encode_failure(command.update_response.rejected) + async def visitor(payload: Payload) -> Payload: + return (await codec.encode([payload]))[0] + + await visit_message(visitor, comp) \ No newline at end of file diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py new file mode 100644 index 000000000..8c81fd51a --- /dev/null +++ b/tests/worker/test_visitor.py @@ -0,0 +1,135 @@ 
+from google.protobuf.duration_pb2 import Duration
+
+from temporalio.api.sdk.v1.user_metadata_pb2 import UserMetadata
+from temporalio.bridge.proto.workflow_commands.workflow_commands_pb2 import (
+    WorkflowCommand,
+    ScheduleActivity,
+    ScheduleLocalActivity,
+    ContinueAsNewWorkflowExecution,
+    StartChildWorkflowExecution,
+    SignalExternalWorkflowExecution,
+    UpdateResponse,
+)
+from temporalio.bridge.proto.workflow_completion.workflow_completion_pb2 import (
+    Success,
+    WorkflowActivationCompletion,
+)
+from temporalio.bridge.visitor import visit_message
+from temporalio.api.common.v1.message_pb2 import Payload, Priority
+
+
+async def test_visit_payloads_mutates_all_payloads_in_message():
+    comp = WorkflowActivationCompletion(
+        run_id="1",
+        successful=Success(
+            commands=[
+                WorkflowCommand(
+                    schedule_activity=ScheduleActivity(
+                        seq=1,
+                        activity_id="1",
+                        activity_type="",
+                        task_queue="",
+                        headers={"foo": Payload(data=b"bar")},
+                        arguments=[Payload(data=b"baz")],
+                        schedule_to_close_timeout=Duration(seconds=5),
+                        priority=Priority(),
+                    ),
+                    user_metadata=UserMetadata(
+                        summary=Payload(data=b"Summary")
+                    ),
+                )
+            ],
+        ),
+    )
+
+    async def visitor(payload: Payload) -> Payload:
+        # Mark visited by prefixing data
+        new_payload = Payload()
+        new_payload.metadata.update(payload.metadata)
+        new_payload.data = b"visited:" + payload.data
+        return new_payload
+
+    await visit_message(visitor, comp)
+
+    cmd = comp.successful.commands[0]
+    sa = cmd.schedule_activity
+    assert sa.headers["foo"].data == b"visited:bar"
+    assert len(sa.arguments) == 1 and sa.arguments[0].data == b"visited:baz"
+
+    assert cmd.user_metadata.summary.data == b"visited:Summary"
+
+
+async def test_visit_payloads_on_other_commands():
+    comp = WorkflowActivationCompletion(
+        run_id="2",
+        successful=Success(
+            commands=[
+                # Continue as new
+                WorkflowCommand(
+                    continue_as_new_workflow_execution=ContinueAsNewWorkflowExecution(
+                        arguments=[Payload(data=b"a1")],
+                        headers={"h1": Payload(data=b"a2")},
+                        memo={"m1": Payload(data=b"a3")},
+                    )
+                ),
+                # Start child
+                WorkflowCommand(
+                    start_child_workflow_execution=StartChildWorkflowExecution(
+                        input=[Payload(data=b"b1")],
+                        headers={"h2": Payload(data=b"b2")},
+                        memo={"m2": Payload(data=b"b3")},
+                    )
+                ),
+                # Signal external
+                WorkflowCommand(
+                    signal_external_workflow_execution=SignalExternalWorkflowExecution(
+                        args=[Payload(data=b"c1")],
+                        headers={"h3": Payload(data=b"c2")},
+                    )
+                ),
+                # Schedule local activity
+                WorkflowCommand(
+                    schedule_local_activity=ScheduleLocalActivity(
+                        arguments=[Payload(data=b"d1")],
+                        headers={"h4": Payload(data=b"d2")},
+                    )
+                ),
+                # Update response completed
+                WorkflowCommand(
+                    update_response=UpdateResponse(
+                        completed=Payload(data=b"e1"),
+                    )
+                ),
+            ]
+        ),
+    )
+
+    async def visitor(payload: Payload) -> Payload:
+        new_payload = Payload()
+        new_payload.metadata.update(payload.metadata)
+        new_payload.data = b"visited:" + payload.data
+        return new_payload
+
+    await visit_message(visitor, comp)
+
+    cmds = comp.successful.commands
+    can = cmds[0].continue_as_new_workflow_execution
+    assert can.arguments[0].data == b"visited:a1"
+    assert can.headers["h1"].data == b"visited:a2"
+    assert can.memo["m1"].data == b"visited:a3"
+
+    sc = cmds[1].start_child_workflow_execution
+    assert sc.input[0].data == b"visited:b1"
+    assert sc.headers["h2"].data == b"visited:b2"
+    assert sc.memo["m2"].data == b"visited:b3"
+
+    se = cmds[2].signal_external_workflow_execution
+    assert se.args[0].data == b"visited:c1"
+    assert se.headers["h3"].data ==
b"visited:c2" + + sla = cmds[3].schedule_local_activity + assert sla.arguments[0].data == b"visited:d1" + assert sla.headers["h4"].data == b"visited:d2" + + ur = cmds[4].update_response + assert ur.completed.data == b"visited:e1" \ No newline at end of file diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index e97bf3e02..82a22028b 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -8275,6 +8275,7 @@ async def test_workflow_headers_with_codec( "Temporal", id=f"workflow-{uuid.uuid4()}", task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=1), ) assert await workflow_handle.result() == "Hello, Temporal!" @@ -8288,6 +8289,7 @@ async def test_workflow_headers_with_codec( SignalAndQueryWorkflow.run, id=f"workflow-{uuid.uuid4()}", task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=1), ) # Simple signals and queries @@ -8327,3 +8329,4 @@ async def test_workflow_headers_with_codec( assert headers["foo"].data == b"bar" else: assert headers["foo"].data != b"bar" + assert False \ No newline at end of file From 7ffa586ba2b2c6849410b4faea675bc880aedf0f Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 2 Sep 2025 16:24:07 -0700 Subject: [PATCH 02/14] Code generation work --- scripts/_proto/Dockerfile | 1 + scripts/gen_protos.py | 3 +- scripts/gen_visitors.py | 162 +++++++++ temporalio/bridge/visitor.py | 101 +++--- temporalio/bridge/visitor_generated.py | 333 ++++++++++++++++++ temporalio/bridge/worker.py | 19 +- temporalio/worker/_workflow.py | 1 + temporalio/worker/_workflow_instance.py | 20 +- temporalio/worker/workflow_sandbox/_runner.py | 3 + temporalio/workflow.py | 16 + tests/test_client.py | 41 +++ tests/worker/test_visitor.py | 107 +++++- 12 files changed, 732 insertions(+), 75 deletions(-) create mode 100644 scripts/gen_visitors.py create mode 100644 temporalio/bridge/visitor_generated.py diff --git a/scripts/_proto/Dockerfile b/scripts/_proto/Dockerfile index 47f3c60dc..5227d883a 100644 --- a/scripts/_proto/Dockerfile +++ b/scripts/_proto/Dockerfile @@ -10,6 +10,7 @@ COPY ./ ./ RUN mkdir -p ./temporalio/api RUN uv add "protobuf<4" RUN uv sync --all-extras +RUN poe build-develop RUN poe gen-protos CMD cp -r ./temporalio/api/* /api_new && cp -r ./temporalio/bridge/proto/* /bridge_new diff --git a/scripts/gen_protos.py b/scripts/gen_protos.py index 61d2709fe..958c49e2b 100644 --- a/scripts/gen_protos.py +++ b/scripts/gen_protos.py @@ -6,7 +6,7 @@ import tempfile from functools import partial from pathlib import Path -from typing import List, Mapping, Optional +from typing import List, Mapping base_dir = Path(__file__).parent.parent proto_dir = ( @@ -201,7 +201,6 @@ def generate_protos(output_dir: Path): / v, ) - if __name__ == "__main__": check_proto_toolchain_versions() print("Generating protos...", file=sys.stderr) diff --git a/scripts/gen_visitors.py b/scripts/gen_visitors.py new file mode 100644 index 000000000..c6bdb2b09 --- /dev/null +++ b/scripts/gen_visitors.py @@ -0,0 +1,162 @@ +import sys +from pathlib import Path + +from google.protobuf.descriptor import Descriptor, FieldDescriptor + +from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes +from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( + WorkflowActivation, +) +from temporalio.bridge.proto.workflow_completion.workflow_completion_pb2 import ( + WorkflowActivationCompletion, +) + +base_dir = Path(__file__).parent.parent + +def gen_workflow_activation_payload_visitor_code() -> str: + 
""" + Generate Python source code that, given a function f(Payload) -> Payload, + applies it to every Payload contained within a WorkflowActivation tree. + + The generated code defines async visitor functions for each reachable + protobuf message type starting from WorkflowActivation, including support + for repeated fields and map entries, and a convenience entrypoint + function `visit_workflow_activation_payloads`. + """ + def name_for(desc: Descriptor) -> str: + # Use fully-qualified name to avoid collisions; replace dots with underscores + return desc.full_name.replace('.', '_') + + def emit_loop(lines: list[str], field_name: str, iter_expr: str, var_name: str, child_method: str) -> None: + # Helper to emit a for-loop over a collection with optional headers guard + if field_name == "headers": + lines.append(" if not self.skip_headers:") + lines.append(f" for {var_name} in {iter_expr}:") + lines.append(f" await self.visit_{child_method}(f, {var_name})") + else: + lines.append(f" for {var_name} in {iter_expr}:") + lines.append(f" await self.visit_{child_method}(f, {var_name})") + + def emit_singular(lines: list[str], field_name: str, access_expr: str, child_method: str) -> None: + # Helper to emit a singular field visit with presence check and optional headers guard + if field_name == "headers": + lines.append(" if not self.skip_headers:") + lines.append(f" if o.HasField('{field_name}'):") + lines.append(f" await self.visit_{child_method}(f, {access_expr})") + else: + lines.append(f" if o.HasField('{field_name}'):") + lines.append(f" await self.visit_{child_method}(f, {access_expr})") + + # Track which message descriptors have visitor methods generated + generated: dict[str, bool] = {} + in_progress: set[str] = set() + methods: list[str] = [] + + def walk(desc: Descriptor) -> bool: + key = desc.full_name + if key in generated: + return generated[key] + if key in in_progress: + # Break cycles; if another path proves this node needed, we'll revisit + return False + + if desc.full_name == Payload.DESCRIPTOR.full_name: + generated[key] = True + methods.append( + """ async def visit_temporal_api_common_v1_Payload(self, f, o): + o.CopyFrom(await f(o)) +""" + ) + return True + + needed = False + in_progress.add(key) + lines: list[str] = [f" async def visit_{name_for(desc)}(self, f, o):"] + # If this is the SearchAttributes message, allow skipping + if desc.full_name == SearchAttributes.DESCRIPTOR.full_name: + lines.append(" if self.skip_search_attributes:") + lines.append(" return") + + for field in desc.fields: + if field.type != FieldDescriptor.TYPE_MESSAGE: + continue + + # Repeated fields (including maps which are represented as repeated messages) + if field.label == FieldDescriptor.LABEL_REPEATED: + if field.message_type is not None and field.message_type.GetOptions().map_entry: + entry_desc = field.message_type + key_fd = entry_desc.fields_by_name.get("key") + val_fd = entry_desc.fields_by_name.get("value") + + if val_fd is not None and val_fd.type == FieldDescriptor.TYPE_MESSAGE: + child_desc = val_fd.message_type + child_needed = walk(child_desc) + needed |= child_needed + if child_needed: + emit_loop(lines, field.name, f"o.{field.name}.values()", "v", name_for(child_desc)) + + if key_fd is not None and key_fd.type == FieldDescriptor.TYPE_MESSAGE: + key_desc = key_fd.message_type + child_needed = walk(key_desc) + needed |= child_needed + if child_needed: + emit_loop(lines, field.name, f"o.{field.name}.keys()", "k", name_for(key_desc)) + else: + child_desc = field.message_type + 
child_needed = walk(child_desc) + needed |= child_needed + if child_needed: + emit_loop(lines, field.name, f"o.{field.name}", "v", name_for(child_desc)) + else: + child_desc = field.message_type + child_needed = walk(child_desc) + needed |= child_needed + if child_needed: + emit_singular(lines, field.name, f"o.{field.name}", name_for(child_desc)) + + generated[key] = needed + in_progress.discard(key) + if needed: + methods.append("\n".join(lines) + "\n") + return needed + + # Build root descriptors: WorkflowActivation, WorkflowActivationCompletion, + # and all messages from selected API modules + roots: list[Descriptor] = [ + WorkflowActivation.DESCRIPTOR, + WorkflowActivationCompletion.DESCRIPTOR, + ] + + # We avoid importing google.api deps in service protos; expand by walking from + # WorkflowActivationCompletion root which references many command messages. + + for r in roots: + walk(r) + + header = ( + "from typing import Awaitable, Callable, Any\n\n" + "from temporalio.api.common.v1.message_pb2 import Payload\n\n\n" + "class PayloadVisitor:\n" + " def __init__(self, *, skip_search_attributes: bool = False, skip_headers: bool = False):\n" + " self.skip_search_attributes = skip_search_attributes\n" + " self.skip_headers = skip_headers\n\n" + " async def visit(self, f: Callable[[Payload], Awaitable[Payload]], root: Any) -> None:\n" + " method_name = 'visit_' + root.DESCRIPTOR.full_name.replace('.', '_')\n" + " method = getattr(self, method_name, None)\n" + " if method is not None:\n" + " await method(f, root)\n\n" + ) + + return header + "\n".join(methods) + + +def write_generated_visitors_into_visitor_generated_py() -> None: + """Write the generated visitor code into visitor_generated.py.""" + out_path = base_dir / "temporalio" / "bridge" / "visitor_generated.py" + code = gen_workflow_activation_payload_visitor_code() + out_path.write_text(code) + +if __name__ == "__main__": + print("Generating temporalio/bridge/visitor_generated.py...", file=sys.stderr) + write_generated_visitors_into_visitor_generated_py() + diff --git a/temporalio/bridge/visitor.py b/temporalio/bridge/visitor.py index aec1751a4..c93d23e78 100644 --- a/temporalio/bridge/visitor.py +++ b/temporalio/bridge/visitor.py @@ -1,53 +1,58 @@ -from typing import Awaitable, Callable, Any - -from collections.abc import Mapping as AbcMapping, Sequence as AbcSequence +from collections.abc import Mapping as AbcMapping +from collections.abc import Sequence as AbcSequence +from typing import Any, Awaitable, Callable from google.protobuf.descriptor import FieldDescriptor from google.protobuf.message import Message -from temporalio.api.common.v1.message_pb2 import Payload - - -async def visit_payloads( - f: Callable[[Payload], Awaitable[Payload]], root: Any -) -> None: - print("Visiting object: ", type(root)) - if isinstance(root, Payload): - print("Applying to payload: ", root) - root.CopyFrom(await f(root)) - print("Applied to payload: ", root) - elif isinstance(root, AbcMapping): - for k, v in root.items(): - await visit_payloads(f, k) - await visit_payloads(f, v) - elif isinstance(root, AbcSequence) and not isinstance( - root, (bytes, bytearray, str) - ): - for o in root: - await visit_payloads(f, o) - elif isinstance(root, Message): - await visit_message(f, root) - - -async def visit_message( - f: Callable[[Payload], Awaitable[Payload]], root: Message -) -> None: - print("Visiting Message: ", type(root)) - for field in root.DESCRIPTOR.fields: - print("Evaluating Field: ", field.name) - - # Repeated fields (including maps which are 
represented as repeated messages) - if field.label == FieldDescriptor.LABEL_REPEATED: - value = getattr(root, field.name) - if field.message_type is not None and field.message_type.GetOptions().map_entry: - for k, v in value.items(): - await visit_payloads(f, k) - await visit_payloads(f, v) - else: - for item in value: - await visit_payloads(f, item) - else: - # Only descend into singular message fields if present - if field.type == FieldDescriptor.TYPE_MESSAGE and root.HasField(field.name): +from temporalio.api.common.v1.message_pb2 import Payload, SearchAttributes + + +class PayloadVisitor: + def __init__(self, *, skip_search_attributes: bool = False, skip_headers: bool = False): + self.skip_search_attributes = skip_search_attributes + self.skip_headers = skip_headers + + async def visit_payloads( + self, f: Callable[[Payload], Awaitable[Payload]], root: Any + ) -> None: + if self.skip_search_attributes and isinstance(root, SearchAttributes): + return + + if isinstance(root, Payload): + root.CopyFrom(await f(root)) + elif isinstance(root, AbcMapping): + for k, v in root.items(): + await self.visit_payloads(f, k) + await self.visit_payloads(f, v) + elif isinstance(root, AbcSequence) and not isinstance( + root, (bytes, bytearray, str) + ): + for o in root: + await self.visit_payloads(f, o) + elif isinstance(root, Message): + await self.visit_message(f, root,) + + + async def visit_message( + self, f: Callable[[Payload], Awaitable[Payload]], root: Message + ) -> None: + for field in root.DESCRIPTOR.fields: + if self.skip_headers and field.name == "headers": + continue + + # Repeated fields (including maps which are represented as repeated messages) + if field.label == FieldDescriptor.LABEL_REPEATED: value = getattr(root, field.name) - await visit_payloads(f, value) + if field.message_type is not None and field.message_type.GetOptions().map_entry: + for k, v in value.items(): + await self.visit_payloads(f, k) + await self.visit_payloads(f, v) + else: + for item in value: + await self.visit_payloads(f, item) + else: + # Only descend into singular message fields if present + if field.type == FieldDescriptor.TYPE_MESSAGE and root.HasField(field.name): + value = getattr(root, field.name) + await self.visit_payloads(f, value) diff --git a/temporalio/bridge/visitor_generated.py b/temporalio/bridge/visitor_generated.py new file mode 100644 index 000000000..995485b4e --- /dev/null +++ b/temporalio/bridge/visitor_generated.py @@ -0,0 +1,333 @@ +from typing import Awaitable, Callable, Any + +from temporalio.api.common.v1.message_pb2 import Payload + + +class PayloadVisitor: + def __init__(self, *, skip_search_attributes: bool = False, skip_headers: bool = False): + self.skip_search_attributes = skip_search_attributes + self.skip_headers = skip_headers + + async def visit(self, f: Callable[[Payload], Awaitable[Payload]], root: Any) -> None: + method_name = 'visit_' + root.DESCRIPTOR.full_name.replace('.', '_') + method = getattr(self, method_name, None) + if method is not None: + await method(f, root) + + async def visit_temporal_api_common_v1_Payload(self, f, o): + o.CopyFrom(await f(o)) + + async def visit_temporal_api_common_v1_Payloads(self, f, o): + for v in o.payloads: + await self.visit_temporal_api_common_v1_Payload(f, v) + + async def visit_temporal_api_failure_v1_ApplicationFailureInfo(self, f, o): + if o.HasField('details'): + await self.visit_temporal_api_common_v1_Payloads(f, o.details) + + async def visit_temporal_api_failure_v1_TimeoutFailureInfo(self, f, o): + if 
o.HasField('last_heartbeat_details'): + await self.visit_temporal_api_common_v1_Payloads(f, o.last_heartbeat_details) + + async def visit_temporal_api_failure_v1_CanceledFailureInfo(self, f, o): + if o.HasField('details'): + await self.visit_temporal_api_common_v1_Payloads(f, o.details) + + async def visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(self, f, o): + if o.HasField('last_heartbeat_details'): + await self.visit_temporal_api_common_v1_Payloads(f, o.last_heartbeat_details) + + async def visit_temporal_api_failure_v1_Failure(self, f, o): + if o.HasField('encoded_attributes'): + await self.visit_temporal_api_common_v1_Payload(f, o.encoded_attributes) + if o.HasField('application_failure_info'): + await self.visit_temporal_api_failure_v1_ApplicationFailureInfo(f, o.application_failure_info) + if o.HasField('timeout_failure_info'): + await self.visit_temporal_api_failure_v1_TimeoutFailureInfo(f, o.timeout_failure_info) + if o.HasField('canceled_failure_info'): + await self.visit_temporal_api_failure_v1_CanceledFailureInfo(f, o.canceled_failure_info) + if o.HasField('reset_workflow_failure_info'): + await self.visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(f, o.reset_workflow_failure_info) + + async def visit_temporal_api_common_v1_Memo(self, f, o): + for v in o.fields.values(): + await self.visit_temporal_api_common_v1_Payload(f, v) + + async def visit_temporal_api_common_v1_SearchAttributes(self, f, o): + if self.skip_search_attributes: + return + for v in o.indexed_fields.values(): + await self.visit_temporal_api_common_v1_Payload(f, v) + + async def visit_coresdk_workflow_activation_InitializeWorkflow(self, f, o): + for v in o.arguments: + await self.visit_temporal_api_common_v1_Payload(f, v) + if not self.skip_headers: + for v in o.headers.values(): + await self.visit_temporal_api_common_v1_Payload(f, v) + if o.HasField('continued_failure'): + await self.visit_temporal_api_failure_v1_Failure(f, o.continued_failure) + if o.HasField('last_completion_result'): + await self.visit_temporal_api_common_v1_Payloads(f, o.last_completion_result) + if o.HasField('memo'): + await self.visit_temporal_api_common_v1_Memo(f, o.memo) + if o.HasField('search_attributes'): + await self.visit_temporal_api_common_v1_SearchAttributes(f, o.search_attributes) + + async def visit_coresdk_workflow_activation_QueryWorkflow(self, f, o): + for v in o.arguments: + await self.visit_temporal_api_common_v1_Payload(f, v) + if not self.skip_headers: + for v in o.headers.values(): + await self.visit_temporal_api_common_v1_Payload(f, v) + + async def visit_coresdk_workflow_activation_SignalWorkflow(self, f, o): + for v in o.input: + await self.visit_temporal_api_common_v1_Payload(f, v) + if not self.skip_headers: + for v in o.headers.values(): + await self.visit_temporal_api_common_v1_Payload(f, v) + + async def visit_coresdk_activity_result_Success(self, f, o): + if o.HasField('result'): + await self.visit_temporal_api_common_v1_Payload(f, o.result) + + async def visit_coresdk_activity_result_Failure(self, f, o): + if o.HasField('failure'): + await self.visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def visit_coresdk_activity_result_Cancellation(self, f, o): + if o.HasField('failure'): + await self.visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def visit_coresdk_activity_result_ActivityResolution(self, f, o): + if o.HasField('completed'): + await self.visit_coresdk_activity_result_Success(f, o.completed) + if o.HasField('failed'): + await 
self.visit_coresdk_activity_result_Failure(f, o.failed) + if o.HasField('cancelled'): + await self.visit_coresdk_activity_result_Cancellation(f, o.cancelled) + + async def visit_coresdk_workflow_activation_ResolveActivity(self, f, o): + if o.HasField('result'): + await self.visit_coresdk_activity_result_ActivityResolution(f, o.result) + + async def visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled(self, f, o): + if o.HasField('failure'): + await self.visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart(self, f, o): + if o.HasField('cancelled'): + await self.visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled(f, o.cancelled) + + async def visit_coresdk_child_workflow_Success(self, f, o): + if o.HasField('result'): + await self.visit_temporal_api_common_v1_Payload(f, o.result) + + async def visit_coresdk_child_workflow_Failure(self, f, o): + if o.HasField('failure'): + await self.visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def visit_coresdk_child_workflow_Cancellation(self, f, o): + if o.HasField('failure'): + await self.visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def visit_coresdk_child_workflow_ChildWorkflowResult(self, f, o): + if o.HasField('completed'): + await self.visit_coresdk_child_workflow_Success(f, o.completed) + if o.HasField('failed'): + await self.visit_coresdk_child_workflow_Failure(f, o.failed) + if o.HasField('cancelled'): + await self.visit_coresdk_child_workflow_Cancellation(f, o.cancelled) + + async def visit_coresdk_workflow_activation_ResolveChildWorkflowExecution(self, f, o): + if o.HasField('result'): + await self.visit_coresdk_child_workflow_ChildWorkflowResult(f, o.result) + + async def visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow(self, f, o): + if o.HasField('failure'): + await self.visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow(self, f, o): + if o.HasField('failure'): + await self.visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def visit_coresdk_workflow_activation_DoUpdate(self, f, o): + for v in o.input: + await self.visit_temporal_api_common_v1_Payload(f, v) + if not self.skip_headers: + for v in o.headers.values(): + await self.visit_temporal_api_common_v1_Payload(f, v) + + async def visit_coresdk_workflow_activation_ResolveNexusOperationStart(self, f, o): + if o.HasField('failed'): + await self.visit_temporal_api_failure_v1_Failure(f, o.failed) + + async def visit_coresdk_nexus_NexusOperationResult(self, f, o): + if o.HasField('completed'): + await self.visit_temporal_api_common_v1_Payload(f, o.completed) + if o.HasField('failed'): + await self.visit_temporal_api_failure_v1_Failure(f, o.failed) + if o.HasField('cancelled'): + await self.visit_temporal_api_failure_v1_Failure(f, o.cancelled) + if o.HasField('timed_out'): + await self.visit_temporal_api_failure_v1_Failure(f, o.timed_out) + + async def visit_coresdk_workflow_activation_ResolveNexusOperation(self, f, o): + if o.HasField('result'): + await self.visit_coresdk_nexus_NexusOperationResult(f, o.result) + + async def visit_coresdk_workflow_activation_WorkflowActivationJob(self, f, o): + if o.HasField('initialize_workflow'): + await self.visit_coresdk_workflow_activation_InitializeWorkflow(f, o.initialize_workflow) + if o.HasField('query_workflow'): + await self.visit_coresdk_workflow_activation_QueryWorkflow(f, 
o.query_workflow) + if o.HasField('signal_workflow'): + await self.visit_coresdk_workflow_activation_SignalWorkflow(f, o.signal_workflow) + if o.HasField('resolve_activity'): + await self.visit_coresdk_workflow_activation_ResolveActivity(f, o.resolve_activity) + if o.HasField('resolve_child_workflow_execution_start'): + await self.visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart(f, o.resolve_child_workflow_execution_start) + if o.HasField('resolve_child_workflow_execution'): + await self.visit_coresdk_workflow_activation_ResolveChildWorkflowExecution(f, o.resolve_child_workflow_execution) + if o.HasField('resolve_signal_external_workflow'): + await self.visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow(f, o.resolve_signal_external_workflow) + if o.HasField('resolve_request_cancel_external_workflow'): + await self.visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow(f, o.resolve_request_cancel_external_workflow) + if o.HasField('do_update'): + await self.visit_coresdk_workflow_activation_DoUpdate(f, o.do_update) + if o.HasField('resolve_nexus_operation_start'): + await self.visit_coresdk_workflow_activation_ResolveNexusOperationStart(f, o.resolve_nexus_operation_start) + if o.HasField('resolve_nexus_operation'): + await self.visit_coresdk_workflow_activation_ResolveNexusOperation(f, o.resolve_nexus_operation) + + async def visit_coresdk_workflow_activation_WorkflowActivation(self, f, o): + for v in o.jobs: + await self.visit_coresdk_workflow_activation_WorkflowActivationJob(f, v) + + async def visit_temporal_api_sdk_v1_UserMetadata(self, f, o): + if o.HasField('summary'): + await self.visit_temporal_api_common_v1_Payload(f, o.summary) + if o.HasField('details'): + await self.visit_temporal_api_common_v1_Payload(f, o.details) + + async def visit_coresdk_workflow_commands_ScheduleActivity(self, f, o): + if not self.skip_headers: + for v in o.headers.values(): + await self.visit_temporal_api_common_v1_Payload(f, v) + for v in o.arguments: + await self.visit_temporal_api_common_v1_Payload(f, v) + + async def visit_coresdk_workflow_commands_QuerySuccess(self, f, o): + if o.HasField('response'): + await self.visit_temporal_api_common_v1_Payload(f, o.response) + + async def visit_coresdk_workflow_commands_QueryResult(self, f, o): + if o.HasField('succeeded'): + await self.visit_coresdk_workflow_commands_QuerySuccess(f, o.succeeded) + if o.HasField('failed'): + await self.visit_temporal_api_failure_v1_Failure(f, o.failed) + + async def visit_coresdk_workflow_commands_CompleteWorkflowExecution(self, f, o): + if o.HasField('result'): + await self.visit_temporal_api_common_v1_Payload(f, o.result) + + async def visit_coresdk_workflow_commands_FailWorkflowExecution(self, f, o): + if o.HasField('failure'): + await self.visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution(self, f, o): + for v in o.arguments: + await self.visit_temporal_api_common_v1_Payload(f, v) + for v in o.memo.values(): + await self.visit_temporal_api_common_v1_Payload(f, v) + if not self.skip_headers: + for v in o.headers.values(): + await self.visit_temporal_api_common_v1_Payload(f, v) + for v in o.search_attributes.values(): + await self.visit_temporal_api_common_v1_Payload(f, v) + + async def visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, f, o): + for v in o.input: + await self.visit_temporal_api_common_v1_Payload(f, v) + if not self.skip_headers: + for v in o.headers.values(): + await 
self.visit_temporal_api_common_v1_Payload(f, v) + for v in o.memo.values(): + await self.visit_temporal_api_common_v1_Payload(f, v) + for v in o.search_attributes.values(): + await self.visit_temporal_api_common_v1_Payload(f, v) + + async def visit_coresdk_workflow_commands_SignalExternalWorkflowExecution(self, f, o): + for v in o.args: + await self.visit_temporal_api_common_v1_Payload(f, v) + if not self.skip_headers: + for v in o.headers.values(): + await self.visit_temporal_api_common_v1_Payload(f, v) + + async def visit_coresdk_workflow_commands_ScheduleLocalActivity(self, f, o): + if not self.skip_headers: + for v in o.headers.values(): + await self.visit_temporal_api_common_v1_Payload(f, v) + for v in o.arguments: + await self.visit_temporal_api_common_v1_Payload(f, v) + + async def visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes(self, f, o): + for v in o.search_attributes.values(): + await self.visit_temporal_api_common_v1_Payload(f, v) + + async def visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, f, o): + if o.HasField('upserted_memo'): + await self.visit_temporal_api_common_v1_Memo(f, o.upserted_memo) + + async def visit_coresdk_workflow_commands_UpdateResponse(self, f, o): + if o.HasField('rejected'): + await self.visit_temporal_api_failure_v1_Failure(f, o.rejected) + if o.HasField('completed'): + await self.visit_temporal_api_common_v1_Payload(f, o.completed) + + async def visit_coresdk_workflow_commands_ScheduleNexusOperation(self, f, o): + if o.HasField('input'): + await self.visit_temporal_api_common_v1_Payload(f, o.input) + + async def visit_coresdk_workflow_commands_WorkflowCommand(self, f, o): + if o.HasField('user_metadata'): + await self.visit_temporal_api_sdk_v1_UserMetadata(f, o.user_metadata) + if o.HasField('schedule_activity'): + await self.visit_coresdk_workflow_commands_ScheduleActivity(f, o.schedule_activity) + if o.HasField('respond_to_query'): + await self.visit_coresdk_workflow_commands_QueryResult(f, o.respond_to_query) + if o.HasField('complete_workflow_execution'): + await self.visit_coresdk_workflow_commands_CompleteWorkflowExecution(f, o.complete_workflow_execution) + if o.HasField('fail_workflow_execution'): + await self.visit_coresdk_workflow_commands_FailWorkflowExecution(f, o.fail_workflow_execution) + if o.HasField('continue_as_new_workflow_execution'): + await self.visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution(f, o.continue_as_new_workflow_execution) + if o.HasField('start_child_workflow_execution'): + await self.visit_coresdk_workflow_commands_StartChildWorkflowExecution(f, o.start_child_workflow_execution) + if o.HasField('signal_external_workflow_execution'): + await self.visit_coresdk_workflow_commands_SignalExternalWorkflowExecution(f, o.signal_external_workflow_execution) + if o.HasField('schedule_local_activity'): + await self.visit_coresdk_workflow_commands_ScheduleLocalActivity(f, o.schedule_local_activity) + if o.HasField('upsert_workflow_search_attributes'): + await self.visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes(f, o.upsert_workflow_search_attributes) + if o.HasField('modify_workflow_properties'): + await self.visit_coresdk_workflow_commands_ModifyWorkflowProperties(f, o.modify_workflow_properties) + if o.HasField('update_response'): + await self.visit_coresdk_workflow_commands_UpdateResponse(f, o.update_response) + if o.HasField('schedule_nexus_operation'): + await self.visit_coresdk_workflow_commands_ScheduleNexusOperation(f, o.schedule_nexus_operation) + + async def 
visit_coresdk_workflow_completion_Success(self, f, o): + for v in o.commands: + await self.visit_coresdk_workflow_commands_WorkflowCommand(f, v) + + async def visit_coresdk_workflow_completion_Failure(self, f, o): + if o.HasField('failure'): + await self.visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def visit_coresdk_workflow_completion_WorkflowActivationCompletion(self, f, o): + if o.HasField('successful'): + await self.visit_coresdk_workflow_completion_Success(f, o.successful) + if o.HasField('failed'): + await self.visit_coresdk_workflow_completion_Failure(f, o.failed) diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 8c66aaaa5..9d3f48ab1 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -35,13 +35,12 @@ import temporalio.bridge.temporal_sdk_bridge import temporalio.converter import temporalio.exceptions +from temporalio.api.common.v1.message_pb2 import Payload from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, ) from temporalio.bridge.temporal_sdk_bridge import PollShutdownError # type: ignore - -from temporalio.api.common.v1.message_pb2 import Payload -from temporalio.bridge.visitor import visit_payloads, visit_message +from temporalio.bridge.visitor import PayloadVisitor @dataclass @@ -367,16 +366,6 @@ async def _decode_payload( return await _apply_to_payload(payload, codec.decode) -async def _encode_payloads( - payloads: PayloadContainer, - codec: temporalio.converter.PayloadCodec, -) -> None: - """Encode payloads with the given codec.""" - async def visitor(payload: Payload) -> Payload: - return (await codec.encode([payload]))[0] - return await visit_payloads(visitor, payloads) - - async def decode_activation( act: temporalio.bridge.proto.workflow_activation.WorkflowActivation, codec: temporalio.converter.PayloadCodec, @@ -386,7 +375,7 @@ async def decode_activation( async def visitor(payload: Payload) -> Payload: return (await codec.decode([payload]))[0] - await visit_message(visitor, act) + await PayloadVisitor(skip_search_attributes=True, skip_headers=not decode_headers).visit_message(visitor, act) async def encode_completion( comp: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion, @@ -397,4 +386,4 @@ async def encode_completion( async def visitor(payload: Payload) -> Payload: return (await codec.encode([payload]))[0] - await visit_message(visitor, comp) \ No newline at end of file + await PayloadVisitor(skip_search_attributes=True, skip_headers=not encode_headers).visit_message(visitor, comp) \ No newline at end of file diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index c7b89206c..cd089c7bc 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -563,6 +563,7 @@ def _create_workflow_instance( extern_functions=self._extern_functions, disable_eager_activity_execution=self._disable_eager_activity_execution, worker_level_failure_exception_types=self._workflow_failure_exception_types, + last_completion_result=init.last_completion_result, ) if defn.sandboxed: return self._workflow_runner.create_instance(det) diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index c93155672..7b212b296 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -64,6 +64,7 @@ import temporalio.workflow from temporalio.service import __version__ +from ..api.common.v1.message_pb2 import Payloads from ._interceptor 
import ( ContinueAsNewInput, ExecuteWorkflowInput, @@ -143,7 +144,7 @@ class WorkflowInstanceDetails: extern_functions: Mapping[str, Callable] disable_eager_activity_execution: bool worker_level_failure_exception_types: Sequence[Type[BaseException]] - + last_completion_result: Payloads class WorkflowInstance(ABC): """Instance of a workflow that can handle activations.""" @@ -320,6 +321,8 @@ def __init__(self, det: WorkflowInstanceDetails) -> None: # metadata query self._current_details = "" + self._last_completion_result = det.last_completion_result + # The versioning behavior of this workflow, as established by annotation or by the dynamic # config function. Is only set once upon initialization. self._versioning_behavior: Optional[temporalio.common.VersioningBehavior] = None @@ -1686,6 +1689,17 @@ def workflow_set_current_details(self, details: str): self._assert_not_read_only("set current details") self._current_details = details + def workflow_last_completion_result(self, type_hint: Optional[Type]) -> Optional[Any]: + print("workflow_last_completion_result: ", self._last_completion_result, type(self._last_completion_result), "payload length:", len(self._last_completion_result.payloads)) + if len(self._last_completion_result.payloads) == 0: + return None + elif len(self._last_completion_result.payloads) > 1: + warnings.warn(f"Expected single last completion result, got {len(self._last_completion_result.payloads)}") + return None + + print("Payload:", self._last_completion_result.payloads[0]) + return self._payload_converter.from_payload(self._last_completion_result.payloads[0], type_hint) + #### Calls from outbound impl #### # These are in alphabetical order and all start with "_outbound_". @@ -2766,6 +2780,10 @@ def _apply_schedule_command( v.start_to_close_timeout.FromTimedelta(self._input.start_to_close_timeout) if self._input.retry_policy: self._input.retry_policy.apply_to_proto(v.retry_policy) + if self._input.summary: + command.user_metadata.summary.CopyFrom( + self._instance._payload_converter.to_payload(self._input.summary) + ) v.cancellation_type = cast( temporalio.bridge.proto.workflow_commands.ActivityCancellationType.ValueType, int(self._input.cancellation_type), diff --git a/temporalio/worker/workflow_sandbox/_runner.py b/temporalio/worker/workflow_sandbox/_runner.py index ba1a7f3ce..960d02f6a 100644 --- a/temporalio/worker/workflow_sandbox/_runner.py +++ b/temporalio/worker/workflow_sandbox/_runner.py @@ -18,6 +18,8 @@ import temporalio.worker._workflow_instance import temporalio.workflow +from ...api.common.v1.message_pb2 import Payloads + # Workflow instance has to be relative import from .._workflow_instance import ( UnsandboxedWorkflowRunner, @@ -84,6 +86,7 @@ def prepare_workflow(self, defn: temporalio.workflow._Definition) -> None: extern_functions={}, disable_eager_activity_execution=False, worker_level_failure_exception_types=self._worker_level_failure_exception_types, + last_completion_result=Payloads(), ), ) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 423d5289b..3c812ea3f 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -897,6 +897,9 @@ def workflow_get_current_details(self) -> str: ... @abstractmethod def workflow_set_current_details(self, details: str): ... + @abstractmethod + def workflow_last_completion_result(self, type_hint: Optional[Type]) -> Optional[Any]: ... 
+ _current_update_info: contextvars.ContextVar[UpdateInfo] = contextvars.ContextVar( "__temporal_current_update_info" @@ -1039,6 +1042,19 @@ def get_current_details() -> str: return _Runtime.current().workflow_get_current_details() +@overload +def get_last_completion_result(type_hint: Type[ParamType]) -> Optional[ParamType]: ... + + +def get_last_completion_result(type_hint: Optional[Type] = None) -> Optional[Any]: + """Get the current details of the workflow which may appear in the UI/CLI. + Unlike static details set at start, this value can be updated throughout + the life of the workflow and is independent of the static details. + This can be in Temporal markdown format and can span multiple lines. + """ + return _Runtime.current().workflow_last_completion_result(type_hint) + + def set_current_details(description: str) -> None: """Set the current details of the workflow which may appear in the UI/CLI. Unlike static details set at start, this value can be updated throughout diff --git a/tests/test_client.py b/tests/test_client.py index 9c33e9e1c..d90bd0ad9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,3 +1,4 @@ +import asyncio import dataclasses import json import os @@ -1501,3 +1502,43 @@ async def test_cloud_client_simple(): GetNamespaceRequest(namespace=os.environ["TEMPORAL_CLIENT_CLOUD_NAMESPACE"]) ) assert os.environ["TEMPORAL_CLIENT_CLOUD_NAMESPACE"] == result.namespace.namespace + + +@workflow.defn +class LastCompletionResultWorkflow: + @workflow.run + async def run(self) -> str: + last_result = workflow.get_last_completion_result(type_hint=str) + if last_result is not None: + return "From last completion:" + last_result + else: + return "My First Result" + + +async def test_schedule_last_completion_result( + client: Client, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip("Java test server doesn't support schedules") + + async with new_worker(client, LastCompletionResultWorkflow) as worker: + handle = await client.create_schedule( + f"schedule-{uuid.uuid4()}", + Schedule( + action=ScheduleActionStartWorkflow( + "LastCompletionResultWorkflow", + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ), + spec=ScheduleSpec(), + ), + ) + await handle.trigger() + await asyncio.sleep(1) + await handle.trigger() + await asyncio.sleep(1) + print(await handle.describe()) + + await handle.delete() + assert False + diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index 8c81fd51a..65537ceda 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -1,24 +1,34 @@ from google.protobuf.duration_pb2 import Duration +from temporalio.api.common.v1.message_pb2 import ( + Payload, + Payloads, + Priority, + SearchAttributes, +) from temporalio.api.sdk.v1.user_metadata_pb2 import UserMetadata +from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( + InitializeWorkflow, + WorkflowActivation, + WorkflowActivationJob, +) from temporalio.bridge.proto.workflow_commands.workflow_commands_pb2 import ( - WorkflowCommand, + ContinueAsNewWorkflowExecution, ScheduleActivity, ScheduleLocalActivity, - ContinueAsNewWorkflowExecution, - StartChildWorkflowExecution, SignalExternalWorkflowExecution, + StartChildWorkflowExecution, UpdateResponse, + WorkflowCommand, ) from temporalio.bridge.proto.workflow_completion.workflow_completion_pb2 import ( Success, WorkflowActivationCompletion, ) -from temporalio.bridge.visitor import visit_message -from temporalio.api.common.v1.message_pb2 import 
Payload, Priority +from temporalio.bridge.visitor_generated import PayloadVisitor -async def test_visit_payloads_mutates_all_payloads_in_message(): +async def test_workflow_activation_completion(): comp = WorkflowActivationCompletion( run_id="1", successful=Success( @@ -49,7 +59,7 @@ async def visitor(payload: Payload) -> Payload: new_payload.data = b"visited:" + payload.data return new_payload - await visit_message(visitor, comp) + await PayloadVisitor().visit(visitor, comp) cmd = comp.successful.commands[0] sa = cmd.schedule_activity @@ -59,6 +69,60 @@ async def visitor(payload: Payload) -> Payload: assert cmd.user_metadata.summary.data == b"visited:Summary" +async def test_workflow_activation(): + original = WorkflowActivation( + jobs=[ + WorkflowActivationJob( + initialize_workflow=InitializeWorkflow( + arguments=[ + Payload(data=b"repeated1"), + Payload(data=b"repeated2"), + ], + headers={ + "header":Payload(data=b"map") + }, + last_completion_result=Payloads( + payloads=[ + Payload(data=b"obj1"), + Payload(data=b"obj2"), + ] + ), + search_attributes=SearchAttributes( + indexed_fields={ + "sakey":Payload(data=b"saobj"), + } + ) + ), + ) + ] + ) + + async def visitor(payload: Payload) -> Payload: + # Mark visited by prefixing data + new_payload = Payload() + new_payload.metadata.update(payload.metadata) + new_payload.metadata["visited"] = b"True" + new_payload.data = payload.data + return new_payload + + act = original.__deepcopy__() + await PayloadVisitor().visit(visitor, act) + assert act.jobs[0].initialize_workflow.arguments[0].metadata["visited"] + assert act.jobs[0].initialize_workflow.arguments[1].metadata["visited"] + assert act.jobs[0].initialize_workflow.headers["header"].metadata["visited"] + assert act.jobs[0].initialize_workflow.last_completion_result.payloads[0].metadata["visited"] + assert act.jobs[0].initialize_workflow.last_completion_result.payloads[1].metadata["visited"] + assert act.jobs[0].initialize_workflow.search_attributes.indexed_fields["sakey"].metadata["visited"] + + act = original.__deepcopy__() + await PayloadVisitor(skip_search_attributes=True).visit(visitor, act) + assert not act.jobs[0].initialize_workflow.search_attributes.indexed_fields["sakey"].metadata["visited"] + + act = original.__deepcopy__() + await PayloadVisitor(skip_headers=True).visit(visitor, act) + assert not act.jobs[0].initialize_workflow.headers["header"].metadata["visited"] + + async def test_visit_payloads_on_other_commands(): comp = WorkflowActivationCompletion( run_id="2", @@ -110,7 +174,7 @@ async def visitor(payload: Payload) -> Payload: new_payload.data = b"visited:" + payload.data return new_payload - await visit_message(visitor, comp) + await PayloadVisitor().visit(visitor, comp) cmds = comp.successful.commands can = cmds[0].continue_as_new_workflow_execution @@ -132,4 +196,29 @@ async def visitor(payload: Payload) -> Payload: assert sla.headers["h4"].data == b"visited:d2" ur = cmds[4].update_response - assert ur.completed.data == b"visited:e1" \ No newline at end of file + assert ur.completed.data == b"visited:e1" + +async def test_code_gen(): + # Smoke test the generated visitor on a simple activation containing payloads + act = WorkflowActivation( + jobs=[ + WorkflowActivationJob( + initialize_workflow=InitializeWorkflow( + arguments=[Payload(data=b"x1"), Payload(data=b"x2")], + headers={"h": Payload(data=b"x3")}, + ) + ) + ] + ) + + async def _f(p: Payload) -> Payload: + q = Payload() + q.metadata.update(p.metadata) + q.data = b"v:" + p.data + return q + + await 
PayloadVisitor().visit(_f, act) + init = act.jobs[0].initialize_workflow + assert init.arguments[0].data == b"v:x1" + assert init.arguments[1].data == b"v:x2" + assert init.headers["h"].data == b"v:x3" \ No newline at end of file From 00ef52888d713aaac55e176ec6e479a3f2bfedb7 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 2 Sep 2025 16:25:26 -0700 Subject: [PATCH 03/14] Linting --- scripts/gen_protos.py | 1 + scripts/gen_visitors.py | 79 +++++++++++++++++++------ temporalio/bridge/visitor.py | 19 ++++-- temporalio/bridge/visitor_generated.py | 2 +- temporalio/bridge/worker.py | 11 +++- temporalio/worker/_workflow_instance.py | 21 +++++-- temporalio/workflow.py | 4 +- tests/test_client.py | 1 - tests/worker/test_visitor.py | 39 ++++++++---- tests/worker/test_workflow.py | 2 +- 10 files changed, 134 insertions(+), 45 deletions(-) diff --git a/scripts/gen_protos.py b/scripts/gen_protos.py index 958c49e2b..32c5c3af2 100644 --- a/scripts/gen_protos.py +++ b/scripts/gen_protos.py @@ -201,6 +201,7 @@ def generate_protos(output_dir: Path): / v, ) + if __name__ == "__main__": check_proto_toolchain_versions() print("Generating protos...", file=sys.stderr) diff --git a/scripts/gen_visitors.py b/scripts/gen_visitors.py index c6bdb2b09..d4dad186c 100644 --- a/scripts/gen_visitors.py +++ b/scripts/gen_visitors.py @@ -13,6 +13,7 @@ base_dir = Path(__file__).parent.parent + def gen_workflow_activation_payload_visitor_code() -> str: """ Generate Python source code that, given a function f(Payload) -> Payload, @@ -23,29 +24,44 @@ def gen_workflow_activation_payload_visitor_code() -> str: for repeated fields and map entries, and a convenience entrypoint function `visit_workflow_activation_payloads`. """ + def name_for(desc: Descriptor) -> str: # Use fully-qualified name to avoid collisions; replace dots with underscores - return desc.full_name.replace('.', '_') - - def emit_loop(lines: list[str], field_name: str, iter_expr: str, var_name: str, child_method: str) -> None: + return desc.full_name.replace(".", "_") + + def emit_loop( + lines: list[str], + field_name: str, + iter_expr: str, + var_name: str, + child_method: str, + ) -> None: # Helper to emit a for-loop over a collection with optional headers guard if field_name == "headers": lines.append(" if not self.skip_headers:") lines.append(f" for {var_name} in {iter_expr}:") - lines.append(f" await self.visit_{child_method}(f, {var_name})") + lines.append( + f" await self.visit_{child_method}(f, {var_name})" + ) else: lines.append(f" for {var_name} in {iter_expr}:") lines.append(f" await self.visit_{child_method}(f, {var_name})") - def emit_singular(lines: list[str], field_name: str, access_expr: str, child_method: str) -> None: + def emit_singular( + lines: list[str], field_name: str, access_expr: str, child_method: str + ) -> None: # Helper to emit a singular field visit with presence check and optional headers guard if field_name == "headers": lines.append(" if not self.skip_headers:") lines.append(f" if o.HasField('{field_name}'):") - lines.append(f" await self.visit_{child_method}(f, {access_expr})") + lines.append( + f" await self.visit_{child_method}(f, {access_expr})" + ) else: lines.append(f" if o.HasField('{field_name}'):") - lines.append(f" await self.visit_{child_method}(f, {access_expr})") + lines.append( + f" await self.visit_{child_method}(f, {access_expr})" + ) # Track which message descriptors have visitor methods generated generated: dict[str, bool] = {} @@ -83,36 +99,65 @@ def walk(desc: Descriptor) -> bool: # Repeated fields 
(including maps which are represented as repeated messages) if field.label == FieldDescriptor.LABEL_REPEATED: - if field.message_type is not None and field.message_type.GetOptions().map_entry: + if ( + field.message_type is not None + and field.message_type.GetOptions().map_entry + ): entry_desc = field.message_type key_fd = entry_desc.fields_by_name.get("key") val_fd = entry_desc.fields_by_name.get("value") - if val_fd is not None and val_fd.type == FieldDescriptor.TYPE_MESSAGE: + if ( + val_fd is not None + and val_fd.type == FieldDescriptor.TYPE_MESSAGE + ): child_desc = val_fd.message_type child_needed = walk(child_desc) needed |= child_needed if child_needed: - emit_loop(lines, field.name, f"o.{field.name}.values()", "v", name_for(child_desc)) - - if key_fd is not None and key_fd.type == FieldDescriptor.TYPE_MESSAGE: + emit_loop( + lines, + field.name, + f"o.{field.name}.values()", + "v", + name_for(child_desc), + ) + + if ( + key_fd is not None + and key_fd.type == FieldDescriptor.TYPE_MESSAGE + ): key_desc = key_fd.message_type child_needed = walk(key_desc) needed |= child_needed if child_needed: - emit_loop(lines, field.name, f"o.{field.name}.keys()", "k", name_for(key_desc)) + emit_loop( + lines, + field.name, + f"o.{field.name}.keys()", + "k", + name_for(key_desc), + ) else: child_desc = field.message_type child_needed = walk(child_desc) needed |= child_needed if child_needed: - emit_loop(lines, field.name, f"o.{field.name}", "v", name_for(child_desc)) + emit_loop( + lines, + field.name, + f"o.{field.name}", + "v", + name_for(child_desc), + ) else: child_desc = field.message_type child_needed = walk(child_desc) needed |= child_needed if child_needed: - emit_singular(lines, field.name, f"o.{field.name}", name_for(child_desc)) + emit_singular( + lines, field.name, f"o.{field.name}", name_for(child_desc) + ) generated[key] = needed in_progress.discard(key) @@ -134,7 +179,7 @@ def walk(desc: Descriptor) -> bool: walk(r) header = ( - "from typing import Awaitable, Callable, Any\n\n" + "from typing import Any, Awaitable, Callable\n\n" "from temporalio.api.common.v1.message_pb2 import Payload\n\n\n" "class PayloadVisitor:\n" " def __init__(self, *, skip_search_attributes: bool = False, skip_headers: bool = False):\n" @@ -156,7 +201,7 @@ def write_generated_visitors_into_visitor_generated_py() -> None: code = gen_workflow_activation_payload_visitor_code() out_path.write_text(code) + if __name__ == "__main__": print("Generating temporalio/bridge/visitor_generated.py...", file=sys.stderr) write_generated_visitors_into_visitor_generated_py() - diff --git a/temporalio/bridge/visitor.py b/temporalio/bridge/visitor.py index c93d23e78..df4bb84ab 100644 --- a/temporalio/bridge/visitor.py +++ b/temporalio/bridge/visitor.py @@ -9,7 +9,9 @@ class PayloadVisitor: - def __init__(self, *, skip_search_attributes: bool = False, skip_headers: bool = False): + def __init__( + self, *, skip_search_attributes: bool = False, skip_headers: bool = False + ): self.skip_search_attributes = skip_search_attributes self.skip_headers = skip_headers @@ -31,8 +33,10 @@ async def visit_payloads( for o in root: await self.visit_payloads(f, o) elif isinstance(root, Message): - await self.visit_message(f, root,) - + await self.visit_message( + f, + root, + ) async def visit_message( self, f: Callable[[Payload], Awaitable[Payload]], root: Message @@ -44,7 +48,10 @@ async def visit_message( # Repeated fields (including maps which are represented as repeated messages) if field.label == FieldDescriptor.LABEL_REPEATED: value = 
getattr(root, field.name) - if field.message_type is not None and field.message_type.GetOptions().map_entry: + if ( + field.message_type is not None + and field.message_type.GetOptions().map_entry + ): for k, v in value.items(): await self.visit_payloads(f, k) await self.visit_payloads(f, v) @@ -53,6 +60,8 @@ async def visit_message( await self.visit_payloads(f, item) else: # Only descend into singular message fields if present - if field.type == FieldDescriptor.TYPE_MESSAGE and root.HasField(field.name): + if field.type == FieldDescriptor.TYPE_MESSAGE and root.HasField( + field.name + ): value = getattr(root, field.name) await self.visit_payloads(f, value) diff --git a/temporalio/bridge/visitor_generated.py b/temporalio/bridge/visitor_generated.py index 995485b4e..42a8aa2b5 100644 --- a/temporalio/bridge/visitor_generated.py +++ b/temporalio/bridge/visitor_generated.py @@ -1,4 +1,4 @@ -from typing import Awaitable, Callable, Any +from typing import Any, Awaitable, Callable from temporalio.api.common.v1.message_pb2 import Payload diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 9d3f48ab1..189e838df 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -372,10 +372,14 @@ async def decode_activation( decode_headers: bool, ) -> None: """Decode the given activation with the codec.""" + async def visitor(payload: Payload) -> Payload: return (await codec.decode([payload]))[0] - await PayloadVisitor(skip_search_attributes=True, skip_headers=not decode_headers).visit_message(visitor, act) + await PayloadVisitor( + skip_search_attributes=True, skip_headers=not decode_headers + ).visit_message(visitor, act) + async def encode_completion( comp: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion, @@ -383,7 +387,10 @@ async def encode_completion( encode_headers: bool, ) -> None: """Recursively encode the given completion with the codec.""" + async def visitor(payload: Payload) -> Payload: return (await codec.encode([payload]))[0] - await PayloadVisitor(skip_search_attributes=True, skip_headers=not encode_headers).visit_message(visitor, comp) \ No newline at end of file + await PayloadVisitor( + skip_search_attributes=True, skip_headers=not encode_headers + ).visit_message(visitor, comp) diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 7b212b296..ee5df6cba 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -146,6 +146,7 @@ class WorkflowInstanceDetails: worker_level_failure_exception_types: Sequence[Type[BaseException]] last_completion_result: Payloads + class WorkflowInstance(ABC): """Instance of a workflow that can handle activations.""" @@ -1689,16 +1690,28 @@ def workflow_set_current_details(self, details: str): self._assert_not_read_only("set current details") self._current_details = details - def workflow_last_completion_result(self, type_hint: Optional[Type]) -> Optional[Any]: - print("workflow_last_completion_result: ", self._last_completion_result, type(self._last_completion_result), "payload length:", len(self._last_completion_result.payloads)) + def workflow_last_completion_result( + self, type_hint: Optional[Type] + ) -> Optional[Any]: + print( + "workflow_last_completion_result: ", + self._last_completion_result, + type(self._last_completion_result), + "payload length:", + len(self._last_completion_result.payloads), + ) if len(self._last_completion_result.payloads) == 0: return None elif 
len(self._last_completion_result.payloads) > 1: - warnings.warn(f"Expected single last completion result, got {len(self._last_completion_result.payloads)}") + warnings.warn( + f"Expected single last completion result, got {len(self._last_completion_result.payloads)}" + ) return None print("Payload:", self._last_completion_result.payloads[0]) - return self._payload_converter.from_payload(self._last_completion_result.payloads[0], type_hint) + return self._payload_converter.from_payload( + self._last_completion_result.payloads[0], type_hint + ) #### Calls from outbound impl #### # These are in alphabetical order and all start with "_outbound_". diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 3c812ea3f..19d280ca2 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -898,7 +898,9 @@ def workflow_get_current_details(self) -> str: ... def workflow_set_current_details(self, details: str): ... @abstractmethod - def workflow_last_completion_result(self, type_hint: Optional[Type]) -> Optional[Any]: ... + def workflow_last_completion_result( + self, type_hint: Optional[Type] + ) -> Optional[Any]: ... _current_update_info: contextvars.ContextVar[UpdateInfo] = contextvars.ContextVar( diff --git a/tests/test_client.py b/tests/test_client.py index d90bd0ad9..773467163 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1541,4 +1541,3 @@ async def test_schedule_last_completion_result( await handle.delete() assert False - diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index 65537ceda..87dd68e8f 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -44,9 +44,7 @@ async def test_workflow_activation_completion(): schedule_to_close_timeout=Duration(seconds=5), priority=Priority(), ), - user_metadata=UserMetadata( - summary=Payload(data=b"Summary") - ), + user_metadata=UserMetadata(summary=Payload(data=b"Summary")), ) ], ), @@ -78,9 +76,7 @@ async def test_workflow_activation(): Payload(data=b"repeated1"), Payload(data=b"repeated2"), ], - headers={ - "header":Payload(data=b"map") - }, + headers={"header": Payload(data=b"map")}, last_completion_result=Payloads( payloads=[ Payload(data=b"obj1"), @@ -89,9 +85,9 @@ async def test_workflow_activation(): ), search_attributes=SearchAttributes( indexed_fields={ - "sakey":Payload(data=b"saobj"), + "sakey": Payload(data=b"saobj"), } - ) + ), ), ) ] @@ -110,13 +106,29 @@ async def visitor(payload: Payload) -> Payload: assert act.jobs[0].initialize_workflow.arguments[0].metadata["visited"] assert act.jobs[0].initialize_workflow.arguments[1].metadata["visited"] assert act.jobs[0].initialize_workflow.headers["header"].metadata["visited"] - assert act.jobs[0].initialize_workflow.last_completion_result.payloads[0].metadata["visited"] - assert act.jobs[0].initialize_workflow.last_completion_result.payloads[1].metadata["visited"] - assert act.jobs[0].initialize_workflow.search_attributes.indexed_fields["sakey"].metadata["visited"] + assert ( + act.jobs[0] + .initialize_workflow.last_completion_result.payloads[0] + .metadata["visited"] + ) + assert ( + act.jobs[0] + .initialize_workflow.last_completion_result.payloads[1] + .metadata["visited"] + ) + assert ( + act.jobs[0] + .initialize_workflow.search_attributes.indexed_fields["sakey"] + .metadata["visited"] + ) act = original.__deepcopy__() await PayloadVisitor(skip_search_attributes=True).visit(visitor, act) - assert not act.jobs[0].initialize_workflow.search_attributes.indexed_fields["sakey"].metadata["visited"] + assert ( + not 
act.jobs[0] + .initialize_workflow.search_attributes.indexed_fields["sakey"] + .metadata["visited"] + ) act = original.__deepcopy__() await PayloadVisitor(skip_headers=True).visit(visitor, act) @@ -198,6 +210,7 @@ async def visitor(payload: Payload) -> Payload: ur = cmds[4].update_response assert ur.completed.data == b"visited:e1" + async def test_code_gen(): # Smoke test the generated visitor on a simple activation containing payloads act = WorkflowActivation( @@ -221,4 +234,4 @@ async def _f(p: Payload) -> Payload: init = act.jobs[0].initialize_workflow assert init.arguments[0].data == b"v:x1" assert init.arguments[1].data == b"v:x2" - assert init.headers["h"].data == b"v:x3" \ No newline at end of file + assert init.headers["h"].data == b"v:x3" diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 82a22028b..279f3f0f8 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -8329,4 +8329,4 @@ async def test_workflow_headers_with_codec( assert headers["foo"].data == b"bar" else: assert headers["foo"].data != b"bar" - assert False \ No newline at end of file + assert False From 379ab2773dd0f9a4262a0ee8014fc6f86d4e0cae Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 2 Sep 2025 17:02:28 -0700 Subject: [PATCH 04/14] Cleanup --- scripts/gen_visitors.py | 132 ++++---- temporalio/bridge/_visitor.py | 421 ++++++++++++++++++++++++ temporalio/bridge/visitor.py | 67 ---- temporalio/bridge/visitor_generated.py | 333 ------------------- temporalio/bridge/worker.py | 6 +- temporalio/worker/_workflow_instance.py | 16 +- temporalio/workflow.py | 13 - tests/test_client.py | 39 --- tests/worker/test_visitor.py | 2 +- 9 files changed, 507 insertions(+), 522 deletions(-) create mode 100644 temporalio/bridge/_visitor.py delete mode 100644 temporalio/bridge/visitor.py delete mode 100644 temporalio/bridge/visitor_generated.py diff --git a/scripts/gen_visitors.py b/scripts/gen_visitors.py index d4dad186c..e017e6d5a 100644 --- a/scripts/gen_visitors.py +++ b/scripts/gen_visitors.py @@ -1,3 +1,4 @@ +import subprocess import sys from pathlib import Path @@ -30,38 +31,33 @@ def name_for(desc: Descriptor) -> str: return desc.full_name.replace(".", "_") def emit_loop( - lines: list[str], field_name: str, iter_expr: str, var_name: str, child_method: str, - ) -> None: + ) -> str: # Helper to emit a for-loop over a collection with optional headers guard if field_name == "headers": - lines.append(" if not self.skip_headers:") - lines.append(f" for {var_name} in {iter_expr}:") - lines.append( - f" await self.visit_{child_method}(f, {var_name})" - ) + return f"""\ + if not self.skip_headers: + for {var_name} in {iter_expr}: + await self._visit_{child_method}(f, {var_name})""" else: - lines.append(f" for {var_name} in {iter_expr}:") - lines.append(f" await self.visit_{child_method}(f, {var_name})") + return f"""\ + for {var_name} in {iter_expr}: + await self._visit_{child_method}(f, {var_name})""" - def emit_singular( - lines: list[str], field_name: str, access_expr: str, child_method: str - ) -> None: + def emit_singular(field_name: str, access_expr: str, child_method: str) -> str: # Helper to emit a singular field visit with presence check and optional headers guard if field_name == "headers": - lines.append(" if not self.skip_headers:") - lines.append(f" if o.HasField('{field_name}'):") - lines.append( - f" await self.visit_{child_method}(f, {access_expr})" - ) + return f"""\ + if not self.skip_headers: + if o.HasField("{field_name}"): + await 
self._visit_{child_method}(f, {access_expr})""" else: - lines.append(f" if o.HasField('{field_name}'):") - lines.append( - f" await self.visit_{child_method}(f, {access_expr})" - ) + return f"""\ + if o.HasField("{field_name}"): + await self._visit_{child_method}(f, {access_expr})""" # Track which message descriptors have visitor methods generated generated: dict[str, bool] = {} @@ -79,7 +75,7 @@ def walk(desc: Descriptor) -> bool: if desc.full_name == Payload.DESCRIPTOR.full_name: generated[key] = True methods.append( - """ async def visit_temporal_api_common_v1_Payload(self, f, o): + """ async def _visit_temporal_api_common_v1_Payload(self, f, o): o.CopyFrom(await f(o)) """ ) @@ -87,7 +83,7 @@ def walk(desc: Descriptor) -> bool: needed = False in_progress.add(key) - lines: list[str] = [f" async def visit_{name_for(desc)}(self, f, o):"] + lines: list[str] = [f" async def _visit_{name_for(desc)}(self, f, o):"] # If this is the SearchAttributes message, allow skipping if desc.full_name == SearchAttributes.DESCRIPTOR.full_name: lines.append(" if self.skip_search_attributes:") @@ -115,12 +111,13 @@ def walk(desc: Descriptor) -> bool: child_needed = walk(child_desc) needed |= child_needed if child_needed: - emit_loop( - lines, - field.name, - f"o.{field.name}.values()", - "v", - name_for(child_desc), + lines.append( + emit_loop( + field.name, + f"o.{field.name}.values()", + "v", + name_for(child_desc), + ) ) if ( @@ -131,32 +128,36 @@ def walk(desc: Descriptor) -> bool: child_needed = walk(key_desc) needed |= child_needed if child_needed: - emit_loop( - lines, - field.name, - f"o.{field.name}.keys()", - "k", - name_for(key_desc), + lines.append( + emit_loop( + field.name, + f"o.{field.name}.keys()", + "k", + name_for(key_desc), + ) ) else: child_desc = field.message_type child_needed = walk(child_desc) needed |= child_needed if child_needed: - emit_loop( - lines, - field.name, - f"o.{field.name}", - "v", - name_for(child_desc), + lines.append( + emit_loop( + field.name, + f"o.{field.name}", + "v", + name_for(child_desc), + ) ) else: child_desc = field.message_type child_needed = walk(child_desc) needed |= child_needed if child_needed: - emit_singular( - lines, field.name, f"o.{field.name}", name_for(child_desc) + lines.append( + emit_singular( + field.name, f"o.{field.name}", name_for(child_desc) + ) ) generated[key] = needed @@ -178,30 +179,43 @@ def walk(desc: Descriptor) -> bool: for r in roots: walk(r) - header = ( - "from typing import Any, Awaitable, Callable\n\n" - "from temporalio.api.common.v1.message_pb2 import Payload\n\n\n" - "class PayloadVisitor:\n" - " def __init__(self, *, skip_search_attributes: bool = False, skip_headers: bool = False):\n" - " self.skip_search_attributes = skip_search_attributes\n" - " self.skip_headers = skip_headers\n\n" - " async def visit(self, f: Callable[[Payload], Awaitable[Payload]], root: Any) -> None:\n" - " method_name = 'visit_' + root.DESCRIPTOR.full_name.replace('.', '_')\n" - " method = getattr(self, method_name, None)\n" - " if method is not None:\n" - " await method(f, root)\n\n" - ) + header = """ +from typing import Any, Awaitable, Callable + +from temporalio.api.common.v1.message_pb2 import Payload + + +class PayloadVisitor: + \"\"\"A visitor for payloads. 
Applies a function to every payload in a tree of messages.\"\"\" + def __init__( + self, *, skip_search_attributes: bool = False, skip_headers: bool = False + ): + \"\"\"Creates a new payload visitor.\"\"\" + self.skip_search_attributes = skip_search_attributes + self.skip_headers = skip_headers + + async def visit( + self, f: Callable[[Payload], Awaitable[Payload]], root: Any + ) -> None: + \"\"\"Visits the given root message with the given function.\"\"\" + method_name = "visit_" + root.DESCRIPTOR.full_name.replace(".", "_") + method = getattr(self, method_name, None) + if method is not None: + await method(f, root) + +""" return header + "\n".join(methods) def write_generated_visitors_into_visitor_generated_py() -> None: - """Write the generated visitor code into visitor_generated.py.""" - out_path = base_dir / "temporalio" / "bridge" / "visitor_generated.py" + """Write the generated visitor code into _visitor.py.""" + out_path = base_dir / "temporalio" / "bridge" / "_visitor.py" code = gen_workflow_activation_payload_visitor_code() out_path.write_text(code) if __name__ == "__main__": - print("Generating temporalio/bridge/visitor_generated.py...", file=sys.stderr) + print("Generating temporalio/bridge/_visitor.py...", file=sys.stderr) write_generated_visitors_into_visitor_generated_py() + subprocess.run(["uv", "run", "ruff", "format", "temporalio/bridge/_visitor.py"]) diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py new file mode 100644 index 000000000..e09797cd6 --- /dev/null +++ b/temporalio/bridge/_visitor.py @@ -0,0 +1,421 @@ +from typing import Any, Awaitable, Callable + +from temporalio.api.common.v1.message_pb2 import Payload + + +class PayloadVisitor: + """A visitor for payloads. Applies a function to every payload in a tree of messages.""" + + def __init__( + self, *, skip_search_attributes: bool = False, skip_headers: bool = False + ): + """Creates a new payload visitor.""" + self.skip_search_attributes = skip_search_attributes + self.skip_headers = skip_headers + + async def visit( + self, f: Callable[[Payload], Awaitable[Payload]], root: Any + ) -> None: + """Visits the given root message with the given function.""" + method_name = "visit_" + root.DESCRIPTOR.full_name.replace(".", "_") + method = getattr(self, method_name, None) + if method is not None: + await method(f, root) + + async def _visit_temporal_api_common_v1_Payload(self, f, o): + o.CopyFrom(await f(o)) + + async def _visit_temporal_api_common_v1_Payloads(self, f, o): + for v in o.payloads: + await self._visit_temporal_api_common_v1_Payload(f, v) + + async def _visit_temporal_api_failure_v1_ApplicationFailureInfo(self, f, o): + if o.HasField("details"): + await self._visit_temporal_api_common_v1_Payloads(f, o.details) + + async def _visit_temporal_api_failure_v1_TimeoutFailureInfo(self, f, o): + if o.HasField("last_heartbeat_details"): + await self._visit_temporal_api_common_v1_Payloads( + f, o.last_heartbeat_details + ) + + async def _visit_temporal_api_failure_v1_CanceledFailureInfo(self, f, o): + if o.HasField("details"): + await self._visit_temporal_api_common_v1_Payloads(f, o.details) + + async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(self, f, o): + if o.HasField("last_heartbeat_details"): + await self._visit_temporal_api_common_v1_Payloads( + f, o.last_heartbeat_details + ) + + async def _visit_temporal_api_failure_v1_Failure(self, f, o): + if o.HasField("encoded_attributes"): + await self._visit_temporal_api_common_v1_Payload(f, o.encoded_attributes) + if 
o.HasField("application_failure_info"): + await self._visit_temporal_api_failure_v1_ApplicationFailureInfo( + f, o.application_failure_info + ) + if o.HasField("timeout_failure_info"): + await self._visit_temporal_api_failure_v1_TimeoutFailureInfo( + f, o.timeout_failure_info + ) + if o.HasField("canceled_failure_info"): + await self._visit_temporal_api_failure_v1_CanceledFailureInfo( + f, o.canceled_failure_info + ) + if o.HasField("reset_workflow_failure_info"): + await self._visit_temporal_api_failure_v1_ResetWorkflowFailureInfo( + f, o.reset_workflow_failure_info + ) + + async def _visit_temporal_api_common_v1_Memo(self, f, o): + for v in o.fields.values(): + await self._visit_temporal_api_common_v1_Payload(f, v) + + async def _visit_temporal_api_common_v1_SearchAttributes(self, f, o): + if self.skip_search_attributes: + return + for v in o.indexed_fields.values(): + await self._visit_temporal_api_common_v1_Payload(f, v) + + async def _visit_coresdk_workflow_activation_InitializeWorkflow(self, f, o): + for v in o.arguments: + await self._visit_temporal_api_common_v1_Payload(f, v) + if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(f, v) + if o.HasField("continued_failure"): + await self._visit_temporal_api_failure_v1_Failure(f, o.continued_failure) + if o.HasField("last_completion_result"): + await self._visit_temporal_api_common_v1_Payloads( + f, o.last_completion_result + ) + if o.HasField("memo"): + await self._visit_temporal_api_common_v1_Memo(f, o.memo) + if o.HasField("search_attributes"): + await self._visit_temporal_api_common_v1_SearchAttributes( + f, o.search_attributes + ) + + async def _visit_coresdk_workflow_activation_QueryWorkflow(self, f, o): + for v in o.arguments: + await self._visit_temporal_api_common_v1_Payload(f, v) + if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(f, v) + + async def _visit_coresdk_workflow_activation_SignalWorkflow(self, f, o): + for v in o.input: + await self._visit_temporal_api_common_v1_Payload(f, v) + if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(f, v) + + async def _visit_coresdk_activity_result_Success(self, f, o): + if o.HasField("result"): + await self._visit_temporal_api_common_v1_Payload(f, o.result) + + async def _visit_coresdk_activity_result_Failure(self, f, o): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def _visit_coresdk_activity_result_Cancellation(self, f, o): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def _visit_coresdk_activity_result_ActivityResolution(self, f, o): + if o.HasField("completed"): + await self._visit_coresdk_activity_result_Success(f, o.completed) + if o.HasField("failed"): + await self._visit_coresdk_activity_result_Failure(f, o.failed) + if o.HasField("cancelled"): + await self._visit_coresdk_activity_result_Cancellation(f, o.cancelled) + + async def _visit_coresdk_workflow_activation_ResolveActivity(self, f, o): + if o.HasField("result"): + await self._visit_coresdk_activity_result_ActivityResolution(f, o.result) + + async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( + self, f, o + ): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( + self, f, o + 
): + if o.HasField("cancelled"): + await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( + f, o.cancelled + ) + + async def _visit_coresdk_child_workflow_Success(self, f, o): + if o.HasField("result"): + await self._visit_temporal_api_common_v1_Payload(f, o.result) + + async def _visit_coresdk_child_workflow_Failure(self, f, o): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def _visit_coresdk_child_workflow_Cancellation(self, f, o): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def _visit_coresdk_child_workflow_ChildWorkflowResult(self, f, o): + if o.HasField("completed"): + await self._visit_coresdk_child_workflow_Success(f, o.completed) + if o.HasField("failed"): + await self._visit_coresdk_child_workflow_Failure(f, o.failed) + if o.HasField("cancelled"): + await self._visit_coresdk_child_workflow_Cancellation(f, o.cancelled) + + async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( + self, f, o + ): + if o.HasField("result"): + await self._visit_coresdk_child_workflow_ChildWorkflowResult(f, o.result) + + async def _visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( + self, f, o + ): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def _visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( + self, f, o + ): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def _visit_coresdk_workflow_activation_DoUpdate(self, f, o): + for v in o.input: + await self._visit_temporal_api_common_v1_Payload(f, v) + if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(f, v) + + async def _visit_coresdk_workflow_activation_ResolveNexusOperationStart(self, f, o): + if o.HasField("failed"): + await self._visit_temporal_api_failure_v1_Failure(f, o.failed) + + async def _visit_coresdk_nexus_NexusOperationResult(self, f, o): + if o.HasField("completed"): + await self._visit_temporal_api_common_v1_Payload(f, o.completed) + if o.HasField("failed"): + await self._visit_temporal_api_failure_v1_Failure(f, o.failed) + if o.HasField("cancelled"): + await self._visit_temporal_api_failure_v1_Failure(f, o.cancelled) + if o.HasField("timed_out"): + await self._visit_temporal_api_failure_v1_Failure(f, o.timed_out) + + async def _visit_coresdk_workflow_activation_ResolveNexusOperation(self, f, o): + if o.HasField("result"): + await self._visit_coresdk_nexus_NexusOperationResult(f, o.result) + + async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, f, o): + if o.HasField("initialize_workflow"): + await self._visit_coresdk_workflow_activation_InitializeWorkflow( + f, o.initialize_workflow + ) + if o.HasField("query_workflow"): + await self._visit_coresdk_workflow_activation_QueryWorkflow( + f, o.query_workflow + ) + if o.HasField("signal_workflow"): + await self._visit_coresdk_workflow_activation_SignalWorkflow( + f, o.signal_workflow + ) + if o.HasField("resolve_activity"): + await self._visit_coresdk_workflow_activation_ResolveActivity( + f, o.resolve_activity + ) + if o.HasField("resolve_child_workflow_execution_start"): + await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( + f, o.resolve_child_workflow_execution_start + ) + if o.HasField("resolve_child_workflow_execution"): + await 
self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( + f, o.resolve_child_workflow_execution + ) + if o.HasField("resolve_signal_external_workflow"): + await self._visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( + f, o.resolve_signal_external_workflow + ) + if o.HasField("resolve_request_cancel_external_workflow"): + await self._visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( + f, o.resolve_request_cancel_external_workflow + ) + if o.HasField("do_update"): + await self._visit_coresdk_workflow_activation_DoUpdate(f, o.do_update) + if o.HasField("resolve_nexus_operation_start"): + await self._visit_coresdk_workflow_activation_ResolveNexusOperationStart( + f, o.resolve_nexus_operation_start + ) + if o.HasField("resolve_nexus_operation"): + await self._visit_coresdk_workflow_activation_ResolveNexusOperation( + f, o.resolve_nexus_operation + ) + + async def _visit_coresdk_workflow_activation_WorkflowActivation(self, f, o): + for v in o.jobs: + await self._visit_coresdk_workflow_activation_WorkflowActivationJob(f, v) + + async def _visit_temporal_api_sdk_v1_UserMetadata(self, f, o): + if o.HasField("summary"): + await self._visit_temporal_api_common_v1_Payload(f, o.summary) + if o.HasField("details"): + await self._visit_temporal_api_common_v1_Payload(f, o.details) + + async def _visit_coresdk_workflow_commands_ScheduleActivity(self, f, o): + if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(f, v) + for v in o.arguments: + await self._visit_temporal_api_common_v1_Payload(f, v) + + async def _visit_coresdk_workflow_commands_QuerySuccess(self, f, o): + if o.HasField("response"): + await self._visit_temporal_api_common_v1_Payload(f, o.response) + + async def _visit_coresdk_workflow_commands_QueryResult(self, f, o): + if o.HasField("succeeded"): + await self._visit_coresdk_workflow_commands_QuerySuccess(f, o.succeeded) + if o.HasField("failed"): + await self._visit_temporal_api_failure_v1_Failure(f, o.failed) + + async def _visit_coresdk_workflow_commands_CompleteWorkflowExecution(self, f, o): + if o.HasField("result"): + await self._visit_temporal_api_common_v1_Payload(f, o.result) + + async def _visit_coresdk_workflow_commands_FailWorkflowExecution(self, f, o): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( + self, f, o + ): + for v in o.arguments: + await self._visit_temporal_api_common_v1_Payload(f, v) + for v in o.memo.values(): + await self._visit_temporal_api_common_v1_Payload(f, v) + if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(f, v) + for v in o.search_attributes.values(): + await self._visit_temporal_api_common_v1_Payload(f, v) + + async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, f, o): + for v in o.input: + await self._visit_temporal_api_common_v1_Payload(f, v) + if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(f, v) + for v in o.memo.values(): + await self._visit_temporal_api_common_v1_Payload(f, v) + for v in o.search_attributes.values(): + await self._visit_temporal_api_common_v1_Payload(f, v) + + async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( + self, f, o + ): + for v in o.args: + await self._visit_temporal_api_common_v1_Payload(f, v) + if not self.skip_headers: + for v 
in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(f, v) + + async def _visit_coresdk_workflow_commands_ScheduleLocalActivity(self, f, o): + if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(f, v) + for v in o.arguments: + await self._visit_temporal_api_common_v1_Payload(f, v) + + async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( + self, f, o + ): + for v in o.search_attributes.values(): + await self._visit_temporal_api_common_v1_Payload(f, v) + + async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, f, o): + if o.HasField("upserted_memo"): + await self._visit_temporal_api_common_v1_Memo(f, o.upserted_memo) + + async def _visit_coresdk_workflow_commands_UpdateResponse(self, f, o): + if o.HasField("rejected"): + await self._visit_temporal_api_failure_v1_Failure(f, o.rejected) + if o.HasField("completed"): + await self._visit_temporal_api_common_v1_Payload(f, o.completed) + + async def _visit_coresdk_workflow_commands_ScheduleNexusOperation(self, f, o): + if o.HasField("input"): + await self._visit_temporal_api_common_v1_Payload(f, o.input) + + async def _visit_coresdk_workflow_commands_WorkflowCommand(self, f, o): + if o.HasField("user_metadata"): + await self._visit_temporal_api_sdk_v1_UserMetadata(f, o.user_metadata) + if o.HasField("schedule_activity"): + await self._visit_coresdk_workflow_commands_ScheduleActivity( + f, o.schedule_activity + ) + if o.HasField("respond_to_query"): + await self._visit_coresdk_workflow_commands_QueryResult( + f, o.respond_to_query + ) + if o.HasField("complete_workflow_execution"): + await self._visit_coresdk_workflow_commands_CompleteWorkflowExecution( + f, o.complete_workflow_execution + ) + if o.HasField("fail_workflow_execution"): + await self._visit_coresdk_workflow_commands_FailWorkflowExecution( + f, o.fail_workflow_execution + ) + if o.HasField("continue_as_new_workflow_execution"): + await self._visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( + f, o.continue_as_new_workflow_execution + ) + if o.HasField("start_child_workflow_execution"): + await self._visit_coresdk_workflow_commands_StartChildWorkflowExecution( + f, o.start_child_workflow_execution + ) + if o.HasField("signal_external_workflow_execution"): + await self._visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( + f, o.signal_external_workflow_execution + ) + if o.HasField("schedule_local_activity"): + await self._visit_coresdk_workflow_commands_ScheduleLocalActivity( + f, o.schedule_local_activity + ) + if o.HasField("upsert_workflow_search_attributes"): + await self._visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( + f, o.upsert_workflow_search_attributes + ) + if o.HasField("modify_workflow_properties"): + await self._visit_coresdk_workflow_commands_ModifyWorkflowProperties( + f, o.modify_workflow_properties + ) + if o.HasField("update_response"): + await self._visit_coresdk_workflow_commands_UpdateResponse( + f, o.update_response + ) + if o.HasField("schedule_nexus_operation"): + await self._visit_coresdk_workflow_commands_ScheduleNexusOperation( + f, o.schedule_nexus_operation + ) + + async def _visit_coresdk_workflow_completion_Success(self, f, o): + for v in o.commands: + await self._visit_coresdk_workflow_commands_WorkflowCommand(f, v) + + async def _visit_coresdk_workflow_completion_Failure(self, f, o): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + + async def 
_visit_coresdk_workflow_completion_WorkflowActivationCompletion( + self, f, o + ): + if o.HasField("successful"): + await self._visit_coresdk_workflow_completion_Success(f, o.successful) + if o.HasField("failed"): + await self._visit_coresdk_workflow_completion_Failure(f, o.failed) diff --git a/temporalio/bridge/visitor.py b/temporalio/bridge/visitor.py deleted file mode 100644 index df4bb84ab..000000000 --- a/temporalio/bridge/visitor.py +++ /dev/null @@ -1,67 +0,0 @@ -from collections.abc import Mapping as AbcMapping -from collections.abc import Sequence as AbcSequence -from typing import Any, Awaitable, Callable - -from google.protobuf.descriptor import FieldDescriptor -from google.protobuf.message import Message - -from temporalio.api.common.v1.message_pb2 import Payload, SearchAttributes - - -class PayloadVisitor: - def __init__( - self, *, skip_search_attributes: bool = False, skip_headers: bool = False - ): - self.skip_search_attributes = skip_search_attributes - self.skip_headers = skip_headers - - async def visit_payloads( - self, f: Callable[[Payload], Awaitable[Payload]], root: Any - ) -> None: - if self.skip_search_attributes and isinstance(root, SearchAttributes): - return - - if isinstance(root, Payload): - root.CopyFrom(await f(root)) - elif isinstance(root, AbcMapping): - for k, v in root.items(): - await self.visit_payloads(f, k) - await self.visit_payloads(f, v) - elif isinstance(root, AbcSequence) and not isinstance( - root, (bytes, bytearray, str) - ): - for o in root: - await self.visit_payloads(f, o) - elif isinstance(root, Message): - await self.visit_message( - f, - root, - ) - - async def visit_message( - self, f: Callable[[Payload], Awaitable[Payload]], root: Message - ) -> None: - for field in root.DESCRIPTOR.fields: - if self.skip_headers and field.name == "headers": - continue - - # Repeated fields (including maps which are represented as repeated messages) - if field.label == FieldDescriptor.LABEL_REPEATED: - value = getattr(root, field.name) - if ( - field.message_type is not None - and field.message_type.GetOptions().map_entry - ): - for k, v in value.items(): - await self.visit_payloads(f, k) - await self.visit_payloads(f, v) - else: - for item in value: - await self.visit_payloads(f, item) - else: - # Only descend into singular message fields if present - if field.type == FieldDescriptor.TYPE_MESSAGE and root.HasField( - field.name - ): - value = getattr(root, field.name) - await self.visit_payloads(f, value) diff --git a/temporalio/bridge/visitor_generated.py b/temporalio/bridge/visitor_generated.py deleted file mode 100644 index 42a8aa2b5..000000000 --- a/temporalio/bridge/visitor_generated.py +++ /dev/null @@ -1,333 +0,0 @@ -from typing import Any, Awaitable, Callable - -from temporalio.api.common.v1.message_pb2 import Payload - - -class PayloadVisitor: - def __init__(self, *, skip_search_attributes: bool = False, skip_headers: bool = False): - self.skip_search_attributes = skip_search_attributes - self.skip_headers = skip_headers - - async def visit(self, f: Callable[[Payload], Awaitable[Payload]], root: Any) -> None: - method_name = 'visit_' + root.DESCRIPTOR.full_name.replace('.', '_') - method = getattr(self, method_name, None) - if method is not None: - await method(f, root) - - async def visit_temporal_api_common_v1_Payload(self, f, o): - o.CopyFrom(await f(o)) - - async def visit_temporal_api_common_v1_Payloads(self, f, o): - for v in o.payloads: - await self.visit_temporal_api_common_v1_Payload(f, v) - - async def 
visit_temporal_api_failure_v1_ApplicationFailureInfo(self, f, o): - if o.HasField('details'): - await self.visit_temporal_api_common_v1_Payloads(f, o.details) - - async def visit_temporal_api_failure_v1_TimeoutFailureInfo(self, f, o): - if o.HasField('last_heartbeat_details'): - await self.visit_temporal_api_common_v1_Payloads(f, o.last_heartbeat_details) - - async def visit_temporal_api_failure_v1_CanceledFailureInfo(self, f, o): - if o.HasField('details'): - await self.visit_temporal_api_common_v1_Payloads(f, o.details) - - async def visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(self, f, o): - if o.HasField('last_heartbeat_details'): - await self.visit_temporal_api_common_v1_Payloads(f, o.last_heartbeat_details) - - async def visit_temporal_api_failure_v1_Failure(self, f, o): - if o.HasField('encoded_attributes'): - await self.visit_temporal_api_common_v1_Payload(f, o.encoded_attributes) - if o.HasField('application_failure_info'): - await self.visit_temporal_api_failure_v1_ApplicationFailureInfo(f, o.application_failure_info) - if o.HasField('timeout_failure_info'): - await self.visit_temporal_api_failure_v1_TimeoutFailureInfo(f, o.timeout_failure_info) - if o.HasField('canceled_failure_info'): - await self.visit_temporal_api_failure_v1_CanceledFailureInfo(f, o.canceled_failure_info) - if o.HasField('reset_workflow_failure_info'): - await self.visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(f, o.reset_workflow_failure_info) - - async def visit_temporal_api_common_v1_Memo(self, f, o): - for v in o.fields.values(): - await self.visit_temporal_api_common_v1_Payload(f, v) - - async def visit_temporal_api_common_v1_SearchAttributes(self, f, o): - if self.skip_search_attributes: - return - for v in o.indexed_fields.values(): - await self.visit_temporal_api_common_v1_Payload(f, v) - - async def visit_coresdk_workflow_activation_InitializeWorkflow(self, f, o): - for v in o.arguments: - await self.visit_temporal_api_common_v1_Payload(f, v) - if not self.skip_headers: - for v in o.headers.values(): - await self.visit_temporal_api_common_v1_Payload(f, v) - if o.HasField('continued_failure'): - await self.visit_temporal_api_failure_v1_Failure(f, o.continued_failure) - if o.HasField('last_completion_result'): - await self.visit_temporal_api_common_v1_Payloads(f, o.last_completion_result) - if o.HasField('memo'): - await self.visit_temporal_api_common_v1_Memo(f, o.memo) - if o.HasField('search_attributes'): - await self.visit_temporal_api_common_v1_SearchAttributes(f, o.search_attributes) - - async def visit_coresdk_workflow_activation_QueryWorkflow(self, f, o): - for v in o.arguments: - await self.visit_temporal_api_common_v1_Payload(f, v) - if not self.skip_headers: - for v in o.headers.values(): - await self.visit_temporal_api_common_v1_Payload(f, v) - - async def visit_coresdk_workflow_activation_SignalWorkflow(self, f, o): - for v in o.input: - await self.visit_temporal_api_common_v1_Payload(f, v) - if not self.skip_headers: - for v in o.headers.values(): - await self.visit_temporal_api_common_v1_Payload(f, v) - - async def visit_coresdk_activity_result_Success(self, f, o): - if o.HasField('result'): - await self.visit_temporal_api_common_v1_Payload(f, o.result) - - async def visit_coresdk_activity_result_Failure(self, f, o): - if o.HasField('failure'): - await self.visit_temporal_api_failure_v1_Failure(f, o.failure) - - async def visit_coresdk_activity_result_Cancellation(self, f, o): - if o.HasField('failure'): - await self.visit_temporal_api_failure_v1_Failure(f, o.failure) - 
- async def visit_coresdk_activity_result_ActivityResolution(self, f, o): - if o.HasField('completed'): - await self.visit_coresdk_activity_result_Success(f, o.completed) - if o.HasField('failed'): - await self.visit_coresdk_activity_result_Failure(f, o.failed) - if o.HasField('cancelled'): - await self.visit_coresdk_activity_result_Cancellation(f, o.cancelled) - - async def visit_coresdk_workflow_activation_ResolveActivity(self, f, o): - if o.HasField('result'): - await self.visit_coresdk_activity_result_ActivityResolution(f, o.result) - - async def visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled(self, f, o): - if o.HasField('failure'): - await self.visit_temporal_api_failure_v1_Failure(f, o.failure) - - async def visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart(self, f, o): - if o.HasField('cancelled'): - await self.visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled(f, o.cancelled) - - async def visit_coresdk_child_workflow_Success(self, f, o): - if o.HasField('result'): - await self.visit_temporal_api_common_v1_Payload(f, o.result) - - async def visit_coresdk_child_workflow_Failure(self, f, o): - if o.HasField('failure'): - await self.visit_temporal_api_failure_v1_Failure(f, o.failure) - - async def visit_coresdk_child_workflow_Cancellation(self, f, o): - if o.HasField('failure'): - await self.visit_temporal_api_failure_v1_Failure(f, o.failure) - - async def visit_coresdk_child_workflow_ChildWorkflowResult(self, f, o): - if o.HasField('completed'): - await self.visit_coresdk_child_workflow_Success(f, o.completed) - if o.HasField('failed'): - await self.visit_coresdk_child_workflow_Failure(f, o.failed) - if o.HasField('cancelled'): - await self.visit_coresdk_child_workflow_Cancellation(f, o.cancelled) - - async def visit_coresdk_workflow_activation_ResolveChildWorkflowExecution(self, f, o): - if o.HasField('result'): - await self.visit_coresdk_child_workflow_ChildWorkflowResult(f, o.result) - - async def visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow(self, f, o): - if o.HasField('failure'): - await self.visit_temporal_api_failure_v1_Failure(f, o.failure) - - async def visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow(self, f, o): - if o.HasField('failure'): - await self.visit_temporal_api_failure_v1_Failure(f, o.failure) - - async def visit_coresdk_workflow_activation_DoUpdate(self, f, o): - for v in o.input: - await self.visit_temporal_api_common_v1_Payload(f, v) - if not self.skip_headers: - for v in o.headers.values(): - await self.visit_temporal_api_common_v1_Payload(f, v) - - async def visit_coresdk_workflow_activation_ResolveNexusOperationStart(self, f, o): - if o.HasField('failed'): - await self.visit_temporal_api_failure_v1_Failure(f, o.failed) - - async def visit_coresdk_nexus_NexusOperationResult(self, f, o): - if o.HasField('completed'): - await self.visit_temporal_api_common_v1_Payload(f, o.completed) - if o.HasField('failed'): - await self.visit_temporal_api_failure_v1_Failure(f, o.failed) - if o.HasField('cancelled'): - await self.visit_temporal_api_failure_v1_Failure(f, o.cancelled) - if o.HasField('timed_out'): - await self.visit_temporal_api_failure_v1_Failure(f, o.timed_out) - - async def visit_coresdk_workflow_activation_ResolveNexusOperation(self, f, o): - if o.HasField('result'): - await self.visit_coresdk_nexus_NexusOperationResult(f, o.result) - - async def visit_coresdk_workflow_activation_WorkflowActivationJob(self, f, o): - if 
o.HasField('initialize_workflow'): - await self.visit_coresdk_workflow_activation_InitializeWorkflow(f, o.initialize_workflow) - if o.HasField('query_workflow'): - await self.visit_coresdk_workflow_activation_QueryWorkflow(f, o.query_workflow) - if o.HasField('signal_workflow'): - await self.visit_coresdk_workflow_activation_SignalWorkflow(f, o.signal_workflow) - if o.HasField('resolve_activity'): - await self.visit_coresdk_workflow_activation_ResolveActivity(f, o.resolve_activity) - if o.HasField('resolve_child_workflow_execution_start'): - await self.visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart(f, o.resolve_child_workflow_execution_start) - if o.HasField('resolve_child_workflow_execution'): - await self.visit_coresdk_workflow_activation_ResolveChildWorkflowExecution(f, o.resolve_child_workflow_execution) - if o.HasField('resolve_signal_external_workflow'): - await self.visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow(f, o.resolve_signal_external_workflow) - if o.HasField('resolve_request_cancel_external_workflow'): - await self.visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow(f, o.resolve_request_cancel_external_workflow) - if o.HasField('do_update'): - await self.visit_coresdk_workflow_activation_DoUpdate(f, o.do_update) - if o.HasField('resolve_nexus_operation_start'): - await self.visit_coresdk_workflow_activation_ResolveNexusOperationStart(f, o.resolve_nexus_operation_start) - if o.HasField('resolve_nexus_operation'): - await self.visit_coresdk_workflow_activation_ResolveNexusOperation(f, o.resolve_nexus_operation) - - async def visit_coresdk_workflow_activation_WorkflowActivation(self, f, o): - for v in o.jobs: - await self.visit_coresdk_workflow_activation_WorkflowActivationJob(f, v) - - async def visit_temporal_api_sdk_v1_UserMetadata(self, f, o): - if o.HasField('summary'): - await self.visit_temporal_api_common_v1_Payload(f, o.summary) - if o.HasField('details'): - await self.visit_temporal_api_common_v1_Payload(f, o.details) - - async def visit_coresdk_workflow_commands_ScheduleActivity(self, f, o): - if not self.skip_headers: - for v in o.headers.values(): - await self.visit_temporal_api_common_v1_Payload(f, v) - for v in o.arguments: - await self.visit_temporal_api_common_v1_Payload(f, v) - - async def visit_coresdk_workflow_commands_QuerySuccess(self, f, o): - if o.HasField('response'): - await self.visit_temporal_api_common_v1_Payload(f, o.response) - - async def visit_coresdk_workflow_commands_QueryResult(self, f, o): - if o.HasField('succeeded'): - await self.visit_coresdk_workflow_commands_QuerySuccess(f, o.succeeded) - if o.HasField('failed'): - await self.visit_temporal_api_failure_v1_Failure(f, o.failed) - - async def visit_coresdk_workflow_commands_CompleteWorkflowExecution(self, f, o): - if o.HasField('result'): - await self.visit_temporal_api_common_v1_Payload(f, o.result) - - async def visit_coresdk_workflow_commands_FailWorkflowExecution(self, f, o): - if o.HasField('failure'): - await self.visit_temporal_api_failure_v1_Failure(f, o.failure) - - async def visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution(self, f, o): - for v in o.arguments: - await self.visit_temporal_api_common_v1_Payload(f, v) - for v in o.memo.values(): - await self.visit_temporal_api_common_v1_Payload(f, v) - if not self.skip_headers: - for v in o.headers.values(): - await self.visit_temporal_api_common_v1_Payload(f, v) - for v in o.search_attributes.values(): - await self.visit_temporal_api_common_v1_Payload(f, v) - - 
async def visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, f, o): - for v in o.input: - await self.visit_temporal_api_common_v1_Payload(f, v) - if not self.skip_headers: - for v in o.headers.values(): - await self.visit_temporal_api_common_v1_Payload(f, v) - for v in o.memo.values(): - await self.visit_temporal_api_common_v1_Payload(f, v) - for v in o.search_attributes.values(): - await self.visit_temporal_api_common_v1_Payload(f, v) - - async def visit_coresdk_workflow_commands_SignalExternalWorkflowExecution(self, f, o): - for v in o.args: - await self.visit_temporal_api_common_v1_Payload(f, v) - if not self.skip_headers: - for v in o.headers.values(): - await self.visit_temporal_api_common_v1_Payload(f, v) - - async def visit_coresdk_workflow_commands_ScheduleLocalActivity(self, f, o): - if not self.skip_headers: - for v in o.headers.values(): - await self.visit_temporal_api_common_v1_Payload(f, v) - for v in o.arguments: - await self.visit_temporal_api_common_v1_Payload(f, v) - - async def visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes(self, f, o): - for v in o.search_attributes.values(): - await self.visit_temporal_api_common_v1_Payload(f, v) - - async def visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, f, o): - if o.HasField('upserted_memo'): - await self.visit_temporal_api_common_v1_Memo(f, o.upserted_memo) - - async def visit_coresdk_workflow_commands_UpdateResponse(self, f, o): - if o.HasField('rejected'): - await self.visit_temporal_api_failure_v1_Failure(f, o.rejected) - if o.HasField('completed'): - await self.visit_temporal_api_common_v1_Payload(f, o.completed) - - async def visit_coresdk_workflow_commands_ScheduleNexusOperation(self, f, o): - if o.HasField('input'): - await self.visit_temporal_api_common_v1_Payload(f, o.input) - - async def visit_coresdk_workflow_commands_WorkflowCommand(self, f, o): - if o.HasField('user_metadata'): - await self.visit_temporal_api_sdk_v1_UserMetadata(f, o.user_metadata) - if o.HasField('schedule_activity'): - await self.visit_coresdk_workflow_commands_ScheduleActivity(f, o.schedule_activity) - if o.HasField('respond_to_query'): - await self.visit_coresdk_workflow_commands_QueryResult(f, o.respond_to_query) - if o.HasField('complete_workflow_execution'): - await self.visit_coresdk_workflow_commands_CompleteWorkflowExecution(f, o.complete_workflow_execution) - if o.HasField('fail_workflow_execution'): - await self.visit_coresdk_workflow_commands_FailWorkflowExecution(f, o.fail_workflow_execution) - if o.HasField('continue_as_new_workflow_execution'): - await self.visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution(f, o.continue_as_new_workflow_execution) - if o.HasField('start_child_workflow_execution'): - await self.visit_coresdk_workflow_commands_StartChildWorkflowExecution(f, o.start_child_workflow_execution) - if o.HasField('signal_external_workflow_execution'): - await self.visit_coresdk_workflow_commands_SignalExternalWorkflowExecution(f, o.signal_external_workflow_execution) - if o.HasField('schedule_local_activity'): - await self.visit_coresdk_workflow_commands_ScheduleLocalActivity(f, o.schedule_local_activity) - if o.HasField('upsert_workflow_search_attributes'): - await self.visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes(f, o.upsert_workflow_search_attributes) - if o.HasField('modify_workflow_properties'): - await self.visit_coresdk_workflow_commands_ModifyWorkflowProperties(f, o.modify_workflow_properties) - if o.HasField('update_response'): - await 
self.visit_coresdk_workflow_commands_UpdateResponse(f, o.update_response) - if o.HasField('schedule_nexus_operation'): - await self.visit_coresdk_workflow_commands_ScheduleNexusOperation(f, o.schedule_nexus_operation) - - async def visit_coresdk_workflow_completion_Success(self, f, o): - for v in o.commands: - await self.visit_coresdk_workflow_commands_WorkflowCommand(f, v) - - async def visit_coresdk_workflow_completion_Failure(self, f, o): - if o.HasField('failure'): - await self.visit_temporal_api_failure_v1_Failure(f, o.failure) - - async def visit_coresdk_workflow_completion_WorkflowActivationCompletion(self, f, o): - if o.HasField('successful'): - await self.visit_coresdk_workflow_completion_Success(f, o.successful) - if o.HasField('failed'): - await self.visit_coresdk_workflow_completion_Failure(f, o.failed) diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 189e838df..1a9a8026d 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -36,11 +36,11 @@ import temporalio.converter import temporalio.exceptions from temporalio.api.common.v1.message_pb2 import Payload +from temporalio.bridge._visitor import PayloadVisitor from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, ) from temporalio.bridge.temporal_sdk_bridge import PollShutdownError # type: ignore -from temporalio.bridge.visitor import PayloadVisitor @dataclass @@ -378,7 +378,7 @@ async def visitor(payload: Payload) -> Payload: await PayloadVisitor( skip_search_attributes=True, skip_headers=not decode_headers - ).visit_message(visitor, act) + ).visit(visitor, act) async def encode_completion( @@ -393,4 +393,4 @@ async def visitor(payload: Payload) -> Payload: await PayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers - ).visit_message(visitor, comp) + ).visit(visitor, comp) diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index ee5df6cba..9dacc6ccf 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -1709,9 +1709,14 @@ def workflow_last_completion_result( return None print("Payload:", self._last_completion_result.payloads[0]) - return self._payload_converter.from_payload( - self._last_completion_result.payloads[0], type_hint - ) + if type_hint is None: + return self._payload_converter.from_payload( + self._last_completion_result.payloads[0] + ) + else: + return self._payload_converter.from_payload( + self._last_completion_result.payloads[0], type_hint + ) #### Calls from outbound impl #### # These are in alphabetical order and all start with "_outbound_". @@ -2793,10 +2798,7 @@ def _apply_schedule_command( v.start_to_close_timeout.FromTimedelta(self._input.start_to_close_timeout) if self._input.retry_policy: self._input.retry_policy.apply_to_proto(v.retry_policy) - if self._input.summary: - command.user_metadata.summary.CopyFrom( - self._instance._payload_converter.to_payload(self._input.summary) - ) + v.cancellation_type = cast( temporalio.bridge.proto.workflow_commands.ActivityCancellationType.ValueType, int(self._input.cancellation_type), diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 19d280ca2..d669417ed 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -1044,19 +1044,6 @@ def get_current_details() -> str: return _Runtime.current().workflow_get_current_details() -@overload -def get_last_completion_result(type_hint: Type[ParamType]) -> Optional[ParamType]: ... 
- - -def get_last_completion_result(type_hint: Optional[Type] = None) -> Optional[Any]: - """Get the current details of the workflow which may appear in the UI/CLI. - Unlike static details set at start, this value can be updated throughout - the life of the workflow and is independent of the static details. - This can be in Temporal markdown format and can span multiple lines. - """ - return _Runtime.current().workflow_last_completion_result(type_hint) - - def set_current_details(description: str) -> None: """Set the current details of the workflow which may appear in the UI/CLI. Unlike static details set at start, this value can be updated throughout diff --git a/tests/test_client.py b/tests/test_client.py index 773467163..f282e3c2a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1502,42 +1502,3 @@ async def test_cloud_client_simple(): GetNamespaceRequest(namespace=os.environ["TEMPORAL_CLIENT_CLOUD_NAMESPACE"]) ) assert os.environ["TEMPORAL_CLIENT_CLOUD_NAMESPACE"] == result.namespace.namespace - - -@workflow.defn -class LastCompletionResultWorkflow: - @workflow.run - async def run(self) -> str: - last_result = workflow.get_last_completion_result(type_hint=str) - if last_result is not None: - return "From last completion:" + last_result - else: - return "My First Result" - - -async def test_schedule_last_completion_result( - client: Client, env: WorkflowEnvironment -): - if env.supports_time_skipping: - pytest.skip("Java test server doesn't support schedules") - - async with new_worker(client, LastCompletionResultWorkflow) as worker: - handle = await client.create_schedule( - f"schedule-{uuid.uuid4()}", - Schedule( - action=ScheduleActionStartWorkflow( - "LastCompletionResultWorkflow", - id=f"workflow-{uuid.uuid4()}", - task_queue=worker.task_queue, - ), - spec=ScheduleSpec(), - ), - ) - await handle.trigger() - await asyncio.sleep(1) - await handle.trigger() - await asyncio.sleep(1) - print(await handle.describe()) - - await handle.delete() - assert False diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index 87dd68e8f..b36732f36 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -7,6 +7,7 @@ SearchAttributes, ) from temporalio.api.sdk.v1.user_metadata_pb2 import UserMetadata +from temporalio.bridge._visitor import PayloadVisitor from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( InitializeWorkflow, WorkflowActivation, @@ -25,7 +26,6 @@ Success, WorkflowActivationCompletion, ) -from temporalio.bridge.visitor_generated import PayloadVisitor async def test_workflow_activation_completion(): From a0f366b718429a3dd90202f1c5324093c7979633 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 2 Sep 2025 17:08:26 -0700 Subject: [PATCH 05/14] Cleanup --- scripts/_proto/Dockerfile | 1 - temporalio/bridge/worker.py | 89 ------------------- temporalio/worker/_workflow.py | 1 - temporalio/worker/_workflow_instance.py | 33 ------- temporalio/worker/workflow_sandbox/_runner.py | 3 - temporalio/workflow.py | 5 -- tests/test_client.py | 1 - tests/worker/test_workflow.py | 3 - 8 files changed, 136 deletions(-) diff --git a/scripts/_proto/Dockerfile b/scripts/_proto/Dockerfile index 5227d883a..47f3c60dc 100644 --- a/scripts/_proto/Dockerfile +++ b/scripts/_proto/Dockerfile @@ -10,7 +10,6 @@ COPY ./ ./ RUN mkdir -p ./temporalio/api RUN uv add "protobuf<4" RUN uv sync --all-extras -RUN poe build-develop RUN poe gen-protos CMD cp -r ./temporalio/api/* /api_new && cp -r ./temporalio/bridge/proto/* /bridge_new 
diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 1a9a8026d..6b0b79e3e 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -20,7 +20,6 @@ ) import google.protobuf.internal.containers -from google.protobuf.message import Message from typing_extensions import TypeAlias import temporalio.api.common.v1 @@ -278,94 +277,6 @@ async def finalize_shutdown(self) -> None: await ref.finalize_shutdown() -# See https://mypy.readthedocs.io/en/stable/runtime_troubles.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime -if TYPE_CHECKING: - PayloadContainer: TypeAlias = ( - google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ - temporalio.api.common.v1.Payload - ] - ) -else: - PayloadContainer: TypeAlias = ( - google.protobuf.internal.containers.RepeatedCompositeFieldContainer - ) - - -async def _apply_to_headers( - headers: Mapping[str, temporalio.api.common.v1.Payload], - cb: Callable[ - [Sequence[temporalio.api.common.v1.Payload]], - Awaitable[List[temporalio.api.common.v1.Payload]], - ], -) -> None: - """Apply API payload callback to headers.""" - for payload in headers.values(): - new_payload = (await cb([payload]))[0] - payload.CopyFrom(new_payload) - - -async def _decode_headers( - headers: Mapping[str, temporalio.api.common.v1.Payload], - codec: temporalio.converter.PayloadCodec, -) -> None: - """Decode headers with the given codec.""" - return await _apply_to_headers(headers, codec.decode) - - -async def _encode_headers( - headers: Mapping[str, temporalio.api.common.v1.Payload], - codec: temporalio.converter.PayloadCodec, -) -> None: - """Encode headers with the given codec.""" - return await _apply_to_headers(headers, codec.encode) - - -async def _apply_to_payloads( - payloads: PayloadContainer, - cb: Callable[ - [Sequence[temporalio.api.common.v1.Payload]], - Awaitable[List[temporalio.api.common.v1.Payload]], - ], -) -> None: - """Apply API payload callback to payloads.""" - if len(payloads) == 0: - return - new_payloads = await cb(payloads) - if new_payloads is payloads: - return - del payloads[:] - # TODO(cretz): Copy too expensive? 
- payloads.extend(new_payloads) - - -async def _apply_to_payload( - payload: temporalio.api.common.v1.Payload, - cb: Callable[ - [Sequence[temporalio.api.common.v1.Payload]], - Awaitable[List[temporalio.api.common.v1.Payload]], - ], -) -> None: - """Apply API payload callback to payload.""" - new_payload = (await cb([payload]))[0] - payload.CopyFrom(new_payload) - - -async def _decode_payloads( - payloads: PayloadContainer, - codec: temporalio.converter.PayloadCodec, -) -> None: - """Decode payloads with the given codec.""" - return await _apply_to_payloads(payloads, codec.decode) - - -async def _decode_payload( - payload: temporalio.api.common.v1.Payload, - codec: temporalio.converter.PayloadCodec, -) -> None: - """Decode a payload with the given codec.""" - return await _apply_to_payload(payload, codec.decode) - - async def decode_activation( act: temporalio.bridge.proto.workflow_activation.WorkflowActivation, codec: temporalio.converter.PayloadCodec, diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index cd089c7bc..c7b89206c 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -563,7 +563,6 @@ def _create_workflow_instance( extern_functions=self._extern_functions, disable_eager_activity_execution=self._disable_eager_activity_execution, worker_level_failure_exception_types=self._workflow_failure_exception_types, - last_completion_result=init.last_completion_result, ) if defn.sandboxed: return self._workflow_runner.create_instance(det) diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 9dacc6ccf..c93155672 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -64,7 +64,6 @@ import temporalio.workflow from temporalio.service import __version__ -from ..api.common.v1.message_pb2 import Payloads from ._interceptor import ( ContinueAsNewInput, ExecuteWorkflowInput, @@ -144,7 +143,6 @@ class WorkflowInstanceDetails: extern_functions: Mapping[str, Callable] disable_eager_activity_execution: bool worker_level_failure_exception_types: Sequence[Type[BaseException]] - last_completion_result: Payloads class WorkflowInstance(ABC): @@ -322,8 +320,6 @@ def __init__(self, det: WorkflowInstanceDetails) -> None: # metadata query self._current_details = "" - self._last_completion_result = det.last_completion_result - # The versioning behavior of this workflow, as established by annotation or by the dynamic # config function. Is only set once upon initialization. 
self._versioning_behavior: Optional[temporalio.common.VersioningBehavior] = None @@ -1690,34 +1686,6 @@ def workflow_set_current_details(self, details: str): self._assert_not_read_only("set current details") self._current_details = details - def workflow_last_completion_result( - self, type_hint: Optional[Type] - ) -> Optional[Any]: - print( - "workflow_last_completion_result: ", - self._last_completion_result, - type(self._last_completion_result), - "payload length:", - len(self._last_completion_result.payloads), - ) - if len(self._last_completion_result.payloads) == 0: - return None - elif len(self._last_completion_result.payloads) > 1: - warnings.warn( - f"Expected single last completion result, got {len(self._last_completion_result.payloads)}" - ) - return None - - print("Payload:", self._last_completion_result.payloads[0]) - if type_hint is None: - return self._payload_converter.from_payload( - self._last_completion_result.payloads[0] - ) - else: - return self._payload_converter.from_payload( - self._last_completion_result.payloads[0], type_hint - ) - #### Calls from outbound impl #### # These are in alphabetical order and all start with "_outbound_". @@ -2798,7 +2766,6 @@ def _apply_schedule_command( v.start_to_close_timeout.FromTimedelta(self._input.start_to_close_timeout) if self._input.retry_policy: self._input.retry_policy.apply_to_proto(v.retry_policy) - v.cancellation_type = cast( temporalio.bridge.proto.workflow_commands.ActivityCancellationType.ValueType, int(self._input.cancellation_type), diff --git a/temporalio/worker/workflow_sandbox/_runner.py b/temporalio/worker/workflow_sandbox/_runner.py index 960d02f6a..ba1a7f3ce 100644 --- a/temporalio/worker/workflow_sandbox/_runner.py +++ b/temporalio/worker/workflow_sandbox/_runner.py @@ -18,8 +18,6 @@ import temporalio.worker._workflow_instance import temporalio.workflow -from ...api.common.v1.message_pb2 import Payloads - # Workflow instance has to be relative import from .._workflow_instance import ( UnsandboxedWorkflowRunner, @@ -86,7 +84,6 @@ def prepare_workflow(self, defn: temporalio.workflow._Definition) -> None: extern_functions={}, disable_eager_activity_execution=False, worker_level_failure_exception_types=self._worker_level_failure_exception_types, - last_completion_result=Payloads(), ), ) diff --git a/temporalio/workflow.py b/temporalio/workflow.py index d669417ed..423d5289b 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -897,11 +897,6 @@ def workflow_get_current_details(self) -> str: ... @abstractmethod def workflow_set_current_details(self, details: str): ... - @abstractmethod - def workflow_last_completion_result( - self, type_hint: Optional[Type] - ) -> Optional[Any]: ... - _current_update_info: contextvars.ContextVar[UpdateInfo] = contextvars.ContextVar( "__temporal_current_update_info" diff --git a/tests/test_client.py b/tests/test_client.py index f282e3c2a..9c33e9e1c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,4 +1,3 @@ -import asyncio import dataclasses import json import os diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 279f3f0f8..e97bf3e02 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -8275,7 +8275,6 @@ async def test_workflow_headers_with_codec( "Temporal", id=f"workflow-{uuid.uuid4()}", task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=1), ) assert await workflow_handle.result() == "Hello, Temporal!" 
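
The _apply_to_payloads helper deleted above existed because protobuf repeated fields cannot be reassigned wholesale; they must be cleared and extended in place, which is also why the new visitor mutates payloads inside the message tree instead of returning fresh containers. A short illustrative sketch of that constraint, assuming only that the temporalio protos are importable:

from temporalio.api.common.v1.message_pb2 import Payload, Payloads

msgs = Payloads(payloads=[Payload(data=b"old")])
replacement = [Payload(data=b"new")]
# msgs.payloads = replacement  # would raise: assignment not allowed to a repeated field
del msgs.payloads[:]
msgs.payloads.extend(replacement)
assert msgs.payloads[0].data == b"new"
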
@@ -8289,7 +8288,6 @@ async def test_workflow_headers_with_codec( SignalAndQueryWorkflow.run, id=f"workflow-{uuid.uuid4()}", task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=1), ) # Simple signals and queries @@ -8329,4 +8327,3 @@ async def test_workflow_headers_with_codec( assert headers["foo"].data == b"bar" else: assert headers["foo"].data != b"bar" - assert False From bc0e8a217ddde496f632f8ff5a2a078496662537 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 3 Sep 2025 08:23:22 -0700 Subject: [PATCH 06/14] Fix generator method name --- scripts/gen_visitors.py | 4 +++- temporalio/bridge/_visitor.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/scripts/gen_visitors.py b/scripts/gen_visitors.py index e017e6d5a..882a61362 100644 --- a/scripts/gen_visitors.py +++ b/scripts/gen_visitors.py @@ -198,10 +198,12 @@ async def visit( self, f: Callable[[Payload], Awaitable[Payload]], root: Any ) -> None: \"\"\"Visits the given root message with the given function.\"\"\" - method_name = "visit_" + root.DESCRIPTOR.full_name.replace(".", "_") + method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") method = getattr(self, method_name, None) if method is not None: await method(f, root) + else: + raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") """ diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py index e09797cd6..37502ab1b 100644 --- a/temporalio/bridge/_visitor.py +++ b/temporalio/bridge/_visitor.py @@ -17,10 +17,12 @@ async def visit( self, f: Callable[[Payload], Awaitable[Payload]], root: Any ) -> None: """Visits the given root message with the given function.""" - method_name = "visit_" + root.DESCRIPTOR.full_name.replace(".", "_") + method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") method = getattr(self, method_name, None) if method is not None: await method(f, root) + else: + raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") async def _visit_temporal_api_common_v1_Payload(self, f, o): o.CopyFrom(await f(o)) From 681b54a1d5bbb2d6dd5334d46624ec29694c3023 Mon Sep 17 00:00:00 2001 From: tconley1428 Date: Wed, 3 Sep 2025 08:42:47 -0700 Subject: [PATCH 07/14] Update gen_visitors.py --- scripts/gen_visitors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/gen_visitors.py b/scripts/gen_visitors.py index 882a61362..3b9d7ed5e 100644 --- a/scripts/gen_visitors.py +++ b/scripts/gen_visitors.py @@ -23,7 +23,7 @@ def gen_workflow_activation_payload_visitor_code() -> str: The generated code defines async visitor functions for each reachable protobuf message type starting from WorkflowActivation, including support for repeated fields and map entries, and a convenience entrypoint - function `visit_workflow_activation_payloads`. + function `visit`. 
""" def name_for(desc: Descriptor) -> str: From 6f9fd187f77f3ac4f0e451016128a02770089e94 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 3 Sep 2025 11:04:50 -0700 Subject: [PATCH 08/14] Allow payload mutation --- scripts/gen_visitors.py | 233 +++++++++++++---------- temporalio/bridge/_visitor.py | 346 +++++++++++++++++----------------- temporalio/bridge/worker.py | 35 ++-- tests/worker/test_visitor.py | 118 ++++++------ 4 files changed, 399 insertions(+), 333 deletions(-) diff --git a/scripts/gen_visitors.py b/scripts/gen_visitors.py index 3b9d7ed5e..7f8eb6e68 100644 --- a/scripts/gen_visitors.py +++ b/scripts/gen_visitors.py @@ -1,6 +1,7 @@ import subprocess import sys from pathlib import Path +from typing import Optional, Tuple from google.protobuf.descriptor import Descriptor, FieldDescriptor @@ -15,75 +16,98 @@ base_dir = Path(__file__).parent.parent -def gen_workflow_activation_payload_visitor_code() -> str: - """ - Generate Python source code that, given a function f(Payload) -> Payload, - applies it to every Payload contained within a WorkflowActivation tree. - - The generated code defines async visitor functions for each reachable - protobuf message type starting from WorkflowActivation, including support - for repeated fields and map entries, and a convenience entrypoint - function `visit`. - """ - - def name_for(desc: Descriptor) -> str: - # Use fully-qualified name to avoid collisions; replace dots with underscores - return desc.full_name.replace(".", "_") - - def emit_loop( - field_name: str, - iter_expr: str, - var_name: str, - child_method: str, - ) -> str: - # Helper to emit a for-loop over a collection with optional headers guard +def name_for(desc: Descriptor) -> str: + # Use fully-qualified name to avoid collisions; replace dots with underscores + return desc.full_name.replace(".", "_") + + +def emit_loop( + field_name: str, + iter_expr: str, + child_method: str, +) -> str: + # Helper to emit a for-loop over a collection with optional headers guard + if field_name == "headers": + return f"""\ + if not self.skip_headers: + for v in {iter_expr}: + await self._visit_{child_method}(fs, v)""" + else: + return f"""\ + for v in {iter_expr}: + await self._visit_{child_method}(fs, v)""" + + +def emit_singular( + field_name: str, access_expr: str, child_method: str, check_presence: bool +) -> str: + # Helper to emit a singular field visit with presence check and optional headers guard + if check_presence: if field_name == "headers": return f"""\ if not self.skip_headers: - for {var_name} in {iter_expr}: - await self._visit_{child_method}(f, {var_name})""" + if o.HasField("{field_name}"): + await self._visit_{child_method}(fs, {access_expr})""" else: return f"""\ - for {var_name} in {iter_expr}: - await self._visit_{child_method}(f, {var_name})""" - - def emit_singular(field_name: str, access_expr: str, child_method: str) -> str: - # Helper to emit a singular field visit with presence check and optional headers guard + if o.HasField("{field_name}"): + await self._visit_{child_method}(fs, {access_expr})""" + else: if field_name == "headers": return f"""\ if not self.skip_headers: - if o.HasField("{field_name}"): - await self._visit_{child_method}(f, {access_expr})""" + await self._visit_{child_method}(fs, {access_expr})""" else: return f"""\ - if o.HasField("{field_name}"): - await self._visit_{child_method}(f, {access_expr})""" - - # Track which message descriptors have visitor methods generated - generated: dict[str, bool] = {} - in_progress: set[str] = set() - methods: 
list[str] = [] + await self._visit_{child_method}(fs, {access_expr})""" + + +class VisitorGenerator: + def __init__(self): + # Track which message descriptors have visitor methods generated + self.generated: dict[str, bool] = { + Payload.DESCRIPTOR.full_name: True, + Payloads.DESCRIPTOR.full_name: True, + } + self.in_progress: set[str] = set() + self.methods: list[str] = [ + """ async def _visit_temporal_api_common_v1_Payload(self, fs, o): + await fs.visit_payload(o) + """, + """ async def _visit_temporal_api_common_v1_Payloads(self, fs, o): + await fs.visit_payloads(o.payloads) + """, + """ async def _visit_payload_container(self, fs, o): + await fs.visit_payloads(o) + """, + ] + + def check_repeated(self, child_desc, field, iter_expr) -> Optional[str]: + # Special case for repeated payloads, handle them directly + if child_desc.full_name == Payload.DESCRIPTOR.full_name: + return emit_singular(field.name, iter_expr, "payload_container", False) + else: + child_needed = self.walk(child_desc) + if child_needed: + return emit_loop( + field.name, + iter_expr, + name_for(child_desc), + ) + else: + return None - def walk(desc: Descriptor) -> bool: + def walk(self, desc: Descriptor) -> bool: key = desc.full_name - if key in generated: - return generated[key] - if key in in_progress: + if key in self.generated: + return self.generated[key] + if key in self.in_progress: # Break cycles; if another path proves this node needed, we'll revisit return False - if desc.full_name == Payload.DESCRIPTOR.full_name: - generated[key] = True - methods.append( - """ async def _visit_temporal_api_common_v1_Payload(self, f, o): - o.CopyFrom(await f(o)) -""" - ) - return True - needed = False - in_progress.add(key) - lines: list[str] = [f" async def _visit_{name_for(desc)}(self, f, o):"] + self.in_progress.add(key) + lines: list[str] = [f" async def _visit_{name_for(desc)}(self, fs, o):"] # If this is the SearchAttributes message, allow skipping if desc.full_name == SearchAttributes.DESCRIPTOR.full_name: lines.append(" if self.skip_search_attributes:") @@ -99,91 +123,96 @@ def walk(desc: Descriptor) -> bool: field.message_type is not None and field.message_type.GetOptions().map_entry ): - entry_desc = field.message_type - key_fd = entry_desc.fields_by_name.get("key") - val_fd = entry_desc.fields_by_name.get("value") - + val_fd = field.message_type.fields_by_name.get("value") if ( val_fd is not None and val_fd.type == FieldDescriptor.TYPE_MESSAGE ): child_desc = val_fd.message_type - child_needed = walk(child_desc) - needed |= child_needed + child_needed = self.walk(child_desc) if child_needed: + needed = True lines.append( emit_loop( field.name, f"o.{field.name}.values()", - "v", name_for(child_desc), ) ) + key_fd = field.message_type.fields_by_name.get("key") if ( key_fd is not None and key_fd.type == FieldDescriptor.TYPE_MESSAGE ): - key_desc = key_fd.message_type - child_needed = walk(key_desc) - needed |= child_needed + child_desc = key_fd.message_type + child_needed = self.walk(child_desc) if child_needed: + needed = True lines.append( emit_loop( field.name, f"o.{field.name}.keys()", - "k", - name_for(key_desc), + name_for(child_desc), ) ) else: - child_desc = field.message_type - child_needed = walk(child_desc) - needed |= child_needed - if child_needed: - lines.append( - emit_loop( - field.name, - f"o.{field.name}", - "v", - name_for(child_desc), - ) - ) + child = self.check_repeated( + field.message_type, field, f"o.{field.name}" + ) + if child is not None: + needed = True + lines.append(child) else: 
child_desc = field.message_type - child_needed = walk(child_desc) + child_needed = self.walk(child_desc) needed |= child_needed if child_needed: lines.append( emit_singular( - field.name, f"o.{field.name}", name_for(child_desc) + field.name, f"o.{field.name}", name_for(child_desc), True ) ) - generated[key] = needed - in_progress.discard(key) + self.generated[key] = needed + self.in_progress.discard(key) if needed: - methods.append("\n".join(lines) + "\n") + self.methods.append("\n".join(lines) + "\n") return needed - # Build root descriptors: WorkflowActivation, WorkflowActivationCompletion, - # and all messages from selected API modules - roots: list[Descriptor] = [ - WorkflowActivation.DESCRIPTOR, - WorkflowActivationCompletion.DESCRIPTOR, - ] + def generate(self, roots: list[Descriptor]) -> str: + """ + Generate Python source code that, given a function f(Payload) -> Payload, + applies it to every Payload contained within a WorkflowActivation tree. - # We avoid importing google.api deps in service protos; expand by walking from - # WorkflowActivationCompletion root which references many command messages. + The generated code defines async visitor functions for each reachable + protobuf message type starting from WorkflowActivation, including support + for repeated fields and map entries, and a convenience entrypoint + function `visit`. + """ - for r in roots: - walk(r) + # We avoid importing google.api deps in service protos; expand by walking from + # WorkflowActivationCompletion root which references many command messages. + for r in roots: + self.walk(r) - header = """ -from typing import Any, Awaitable, Callable + header = """ +import abc +from typing import Any, MutableSequence from temporalio.api.common.v1.message_pb2 import Payload +class VisitorFunctions(abc.ABC): + \"\"\"Set of functions which can be called by the visitor. Allows handling payloads as a sequence.\"\"\" + @abc.abstractmethod + async def visit_payload(self, payload: Payload) -> None: + \"\"\"Called when encountering a single payload.\"\"\" + raise NotImplementedError() + + @abc.abstractmethod + async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + \"\"\"Called when encountering multiple payloads together.\"\"\" + raise NotImplementedError() class PayloadVisitor: \"\"\"A visitor for payloads. 
Applies a function to every payload in a tree of messages.\"\"\" @@ -195,25 +224,33 @@ def __init__( self.skip_headers = skip_headers async def visit( - self, f: Callable[[Payload], Awaitable[Payload]], root: Any + self, fs: VisitorFunctions, root: Any ) -> None: \"\"\"Visits the given root message with the given function.\"\"\" method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") method = getattr(self, method_name, None) if method is not None: - await method(f, root) + await method(fs, root) else: raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") """ - return header + "\n".join(methods) + return header + "\n".join(self.methods) def write_generated_visitors_into_visitor_generated_py() -> None: """Write the generated visitor code into _visitor.py.""" out_path = base_dir / "temporalio" / "bridge" / "_visitor.py" - code = gen_workflow_activation_payload_visitor_code() + + # Build root descriptors: WorkflowActivation, WorkflowActivationCompletion, + # and all messages from selected API modules + roots: list[Descriptor] = [ + WorkflowActivation.DESCRIPTOR, + WorkflowActivationCompletion.DESCRIPTOR, + ] + + code = VisitorGenerator().generate(roots) out_path.write_text(code) diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py index 37502ab1b..e99087ecf 100644 --- a/temporalio/bridge/_visitor.py +++ b/temporalio/bridge/_visitor.py @@ -1,8 +1,23 @@ -from typing import Any, Awaitable, Callable +import abc +from typing import Any, MutableSequence from temporalio.api.common.v1.message_pb2 import Payload +class VisitorFunctions(abc.ABC): + """Set of functions which can be called by the visitor. Allows handling payloads as a sequence.""" + + @abc.abstractmethod + async def visit_payload(self, payload: Payload) -> None: + """Called when encountering a single payload.""" + raise NotImplementedError() + + @abc.abstractmethod + async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + """Called when encountering multiple payloads together.""" + raise NotImplementedError() + + class PayloadVisitor: """A visitor for payloads. 
Applies a function to every payload in a tree of messages.""" @@ -13,411 +28,404 @@ def __init__( self.skip_search_attributes = skip_search_attributes self.skip_headers = skip_headers - async def visit( - self, f: Callable[[Payload], Awaitable[Payload]], root: Any - ) -> None: + async def visit(self, fs: VisitorFunctions, root: Any) -> None: """Visits the given root message with the given function.""" method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") method = getattr(self, method_name, None) if method is not None: - await method(f, root) + await method(fs, root) else: raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") - async def _visit_temporal_api_common_v1_Payload(self, f, o): - o.CopyFrom(await f(o)) + async def _visit_temporal_api_common_v1_Payload(self, fs, o): + await fs.visit_payload(o) + + async def _visit_temporal_api_common_v1_Payloads(self, fs, o): + await fs.visit_payloads(o.payloads) - async def _visit_temporal_api_common_v1_Payloads(self, f, o): - for v in o.payloads: - await self._visit_temporal_api_common_v1_Payload(f, v) + async def _visit_payload_container(self, fs, o): + await fs.visit_payloads(o) - async def _visit_temporal_api_failure_v1_ApplicationFailureInfo(self, f, o): + async def _visit_temporal_api_failure_v1_ApplicationFailureInfo(self, fs, o): if o.HasField("details"): - await self._visit_temporal_api_common_v1_Payloads(f, o.details) + await self._visit_temporal_api_common_v1_Payloads(fs, o.details) - async def _visit_temporal_api_failure_v1_TimeoutFailureInfo(self, f, o): + async def _visit_temporal_api_failure_v1_TimeoutFailureInfo(self, fs, o): if o.HasField("last_heartbeat_details"): await self._visit_temporal_api_common_v1_Payloads( - f, o.last_heartbeat_details + fs, o.last_heartbeat_details ) - async def _visit_temporal_api_failure_v1_CanceledFailureInfo(self, f, o): + async def _visit_temporal_api_failure_v1_CanceledFailureInfo(self, fs, o): if o.HasField("details"): - await self._visit_temporal_api_common_v1_Payloads(f, o.details) + await self._visit_temporal_api_common_v1_Payloads(fs, o.details) - async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(self, f, o): + async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(self, fs, o): if o.HasField("last_heartbeat_details"): await self._visit_temporal_api_common_v1_Payloads( - f, o.last_heartbeat_details + fs, o.last_heartbeat_details ) - async def _visit_temporal_api_failure_v1_Failure(self, f, o): + async def _visit_temporal_api_failure_v1_Failure(self, fs, o): if o.HasField("encoded_attributes"): - await self._visit_temporal_api_common_v1_Payload(f, o.encoded_attributes) + await self._visit_temporal_api_common_v1_Payload(fs, o.encoded_attributes) if o.HasField("application_failure_info"): await self._visit_temporal_api_failure_v1_ApplicationFailureInfo( - f, o.application_failure_info + fs, o.application_failure_info ) if o.HasField("timeout_failure_info"): await self._visit_temporal_api_failure_v1_TimeoutFailureInfo( - f, o.timeout_failure_info + fs, o.timeout_failure_info ) if o.HasField("canceled_failure_info"): await self._visit_temporal_api_failure_v1_CanceledFailureInfo( - f, o.canceled_failure_info + fs, o.canceled_failure_info ) if o.HasField("reset_workflow_failure_info"): await self._visit_temporal_api_failure_v1_ResetWorkflowFailureInfo( - f, o.reset_workflow_failure_info + fs, o.reset_workflow_failure_info ) - async def _visit_temporal_api_common_v1_Memo(self, f, o): + async def _visit_temporal_api_common_v1_Memo(self, 
fs, o): for v in o.fields.values(): - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_temporal_api_common_v1_SearchAttributes(self, f, o): + async def _visit_temporal_api_common_v1_SearchAttributes(self, fs, o): if self.skip_search_attributes: return for v in o.indexed_fields.values(): - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_activation_InitializeWorkflow(self, f, o): - for v in o.arguments: - await self._visit_temporal_api_common_v1_Payload(f, v) + async def _visit_coresdk_workflow_activation_InitializeWorkflow(self, fs, o): + await self._visit_payload_container(fs, o.arguments) if not self.skip_headers: for v in o.headers.values(): - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_temporal_api_common_v1_Payload(fs, v) if o.HasField("continued_failure"): - await self._visit_temporal_api_failure_v1_Failure(f, o.continued_failure) + await self._visit_temporal_api_failure_v1_Failure(fs, o.continued_failure) if o.HasField("last_completion_result"): await self._visit_temporal_api_common_v1_Payloads( - f, o.last_completion_result + fs, o.last_completion_result ) if o.HasField("memo"): - await self._visit_temporal_api_common_v1_Memo(f, o.memo) + await self._visit_temporal_api_common_v1_Memo(fs, o.memo) if o.HasField("search_attributes"): await self._visit_temporal_api_common_v1_SearchAttributes( - f, o.search_attributes + fs, o.search_attributes ) - async def _visit_coresdk_workflow_activation_QueryWorkflow(self, f, o): - for v in o.arguments: - await self._visit_temporal_api_common_v1_Payload(f, v) + async def _visit_coresdk_workflow_activation_QueryWorkflow(self, fs, o): + await self._visit_payload_container(fs, o.arguments) if not self.skip_headers: for v in o.headers.values(): - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_activation_SignalWorkflow(self, f, o): - for v in o.input: - await self._visit_temporal_api_common_v1_Payload(f, v) + async def _visit_coresdk_workflow_activation_SignalWorkflow(self, fs, o): + await self._visit_payload_container(fs, o.input) if not self.skip_headers: for v in o.headers.values(): - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_activity_result_Success(self, f, o): + async def _visit_coresdk_activity_result_Success(self, fs, o): if o.HasField("result"): - await self._visit_temporal_api_common_v1_Payload(f, o.result) + await self._visit_temporal_api_common_v1_Payload(fs, o.result) - async def _visit_coresdk_activity_result_Failure(self, f, o): + async def _visit_coresdk_activity_result_Failure(self, fs, o): if o.HasField("failure"): - await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_activity_result_Cancellation(self, f, o): + async def _visit_coresdk_activity_result_Cancellation(self, fs, o): if o.HasField("failure"): - await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_activity_result_ActivityResolution(self, f, o): + async def _visit_coresdk_activity_result_ActivityResolution(self, fs, o): if o.HasField("completed"): - await 
self._visit_coresdk_activity_result_Success(f, o.completed) + await self._visit_coresdk_activity_result_Success(fs, o.completed) if o.HasField("failed"): - await self._visit_coresdk_activity_result_Failure(f, o.failed) + await self._visit_coresdk_activity_result_Failure(fs, o.failed) if o.HasField("cancelled"): - await self._visit_coresdk_activity_result_Cancellation(f, o.cancelled) + await self._visit_coresdk_activity_result_Cancellation(fs, o.cancelled) - async def _visit_coresdk_workflow_activation_ResolveActivity(self, f, o): + async def _visit_coresdk_workflow_activation_ResolveActivity(self, fs, o): if o.HasField("result"): - await self._visit_coresdk_activity_result_ActivityResolution(f, o.result) + await self._visit_coresdk_activity_result_ActivityResolution(fs, o.result) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( - self, f, o + self, fs, o ): if o.HasField("failure"): - await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( - self, f, o + self, fs, o ): if o.HasField("cancelled"): await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( - f, o.cancelled + fs, o.cancelled ) - async def _visit_coresdk_child_workflow_Success(self, f, o): + async def _visit_coresdk_child_workflow_Success(self, fs, o): if o.HasField("result"): - await self._visit_temporal_api_common_v1_Payload(f, o.result) + await self._visit_temporal_api_common_v1_Payload(fs, o.result) - async def _visit_coresdk_child_workflow_Failure(self, f, o): + async def _visit_coresdk_child_workflow_Failure(self, fs, o): if o.HasField("failure"): - await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_child_workflow_Cancellation(self, f, o): + async def _visit_coresdk_child_workflow_Cancellation(self, fs, o): if o.HasField("failure"): - await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_child_workflow_ChildWorkflowResult(self, f, o): + async def _visit_coresdk_child_workflow_ChildWorkflowResult(self, fs, o): if o.HasField("completed"): - await self._visit_coresdk_child_workflow_Success(f, o.completed) + await self._visit_coresdk_child_workflow_Success(fs, o.completed) if o.HasField("failed"): - await self._visit_coresdk_child_workflow_Failure(f, o.failed) + await self._visit_coresdk_child_workflow_Failure(fs, o.failed) if o.HasField("cancelled"): - await self._visit_coresdk_child_workflow_Cancellation(f, o.cancelled) + await self._visit_coresdk_child_workflow_Cancellation(fs, o.cancelled) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( - self, f, o + self, fs, o ): if o.HasField("result"): - await self._visit_coresdk_child_workflow_ChildWorkflowResult(f, o.result) + await self._visit_coresdk_child_workflow_ChildWorkflowResult(fs, o.result) async def _visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( - self, f, o + self, fs, o ): if o.HasField("failure"): - await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( - self, f, o + self, fs, o ): if o.HasField("failure"): - await 
self._visit_temporal_api_failure_v1_Failure(f, o.failure) + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) - async def _visit_coresdk_workflow_activation_DoUpdate(self, f, o): - for v in o.input: - await self._visit_temporal_api_common_v1_Payload(f, v) + async def _visit_coresdk_workflow_activation_DoUpdate(self, fs, o): + await self._visit_payload_container(fs, o.input) if not self.skip_headers: for v in o.headers.values(): - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_activation_ResolveNexusOperationStart(self, f, o): + async def _visit_coresdk_workflow_activation_ResolveNexusOperationStart( + self, fs, o + ): if o.HasField("failed"): - await self._visit_temporal_api_failure_v1_Failure(f, o.failed) + await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) - async def _visit_coresdk_nexus_NexusOperationResult(self, f, o): + async def _visit_coresdk_nexus_NexusOperationResult(self, fs, o): if o.HasField("completed"): - await self._visit_temporal_api_common_v1_Payload(f, o.completed) + await self._visit_temporal_api_common_v1_Payload(fs, o.completed) if o.HasField("failed"): - await self._visit_temporal_api_failure_v1_Failure(f, o.failed) + await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) if o.HasField("cancelled"): - await self._visit_temporal_api_failure_v1_Failure(f, o.cancelled) + await self._visit_temporal_api_failure_v1_Failure(fs, o.cancelled) if o.HasField("timed_out"): - await self._visit_temporal_api_failure_v1_Failure(f, o.timed_out) + await self._visit_temporal_api_failure_v1_Failure(fs, o.timed_out) - async def _visit_coresdk_workflow_activation_ResolveNexusOperation(self, f, o): + async def _visit_coresdk_workflow_activation_ResolveNexusOperation(self, fs, o): if o.HasField("result"): - await self._visit_coresdk_nexus_NexusOperationResult(f, o.result) + await self._visit_coresdk_nexus_NexusOperationResult(fs, o.result) - async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, f, o): + async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, fs, o): if o.HasField("initialize_workflow"): await self._visit_coresdk_workflow_activation_InitializeWorkflow( - f, o.initialize_workflow + fs, o.initialize_workflow ) if o.HasField("query_workflow"): await self._visit_coresdk_workflow_activation_QueryWorkflow( - f, o.query_workflow + fs, o.query_workflow ) if o.HasField("signal_workflow"): await self._visit_coresdk_workflow_activation_SignalWorkflow( - f, o.signal_workflow + fs, o.signal_workflow ) if o.HasField("resolve_activity"): await self._visit_coresdk_workflow_activation_ResolveActivity( - f, o.resolve_activity + fs, o.resolve_activity ) if o.HasField("resolve_child_workflow_execution_start"): await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( - f, o.resolve_child_workflow_execution_start + fs, o.resolve_child_workflow_execution_start ) if o.HasField("resolve_child_workflow_execution"): await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( - f, o.resolve_child_workflow_execution + fs, o.resolve_child_workflow_execution ) if o.HasField("resolve_signal_external_workflow"): await self._visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( - f, o.resolve_signal_external_workflow + fs, o.resolve_signal_external_workflow ) if o.HasField("resolve_request_cancel_external_workflow"): await 
self._visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( - f, o.resolve_request_cancel_external_workflow + fs, o.resolve_request_cancel_external_workflow ) if o.HasField("do_update"): - await self._visit_coresdk_workflow_activation_DoUpdate(f, o.do_update) + await self._visit_coresdk_workflow_activation_DoUpdate(fs, o.do_update) if o.HasField("resolve_nexus_operation_start"): await self._visit_coresdk_workflow_activation_ResolveNexusOperationStart( - f, o.resolve_nexus_operation_start + fs, o.resolve_nexus_operation_start ) if o.HasField("resolve_nexus_operation"): await self._visit_coresdk_workflow_activation_ResolveNexusOperation( - f, o.resolve_nexus_operation + fs, o.resolve_nexus_operation ) - async def _visit_coresdk_workflow_activation_WorkflowActivation(self, f, o): + async def _visit_coresdk_workflow_activation_WorkflowActivation(self, fs, o): for v in o.jobs: - await self._visit_coresdk_workflow_activation_WorkflowActivationJob(f, v) + await self._visit_coresdk_workflow_activation_WorkflowActivationJob(fs, v) - async def _visit_temporal_api_sdk_v1_UserMetadata(self, f, o): + async def _visit_temporal_api_sdk_v1_UserMetadata(self, fs, o): if o.HasField("summary"): - await self._visit_temporal_api_common_v1_Payload(f, o.summary) + await self._visit_temporal_api_common_v1_Payload(fs, o.summary) if o.HasField("details"): - await self._visit_temporal_api_common_v1_Payload(f, o.details) + await self._visit_temporal_api_common_v1_Payload(fs, o.details) - async def _visit_coresdk_workflow_commands_ScheduleActivity(self, f, o): + async def _visit_coresdk_workflow_commands_ScheduleActivity(self, fs, o): if not self.skip_headers: for v in o.headers.values(): - await self._visit_temporal_api_common_v1_Payload(f, v) - for v in o.arguments: - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_temporal_api_common_v1_Payload(fs, v) + await self._visit_payload_container(fs, o.arguments) - async def _visit_coresdk_workflow_commands_QuerySuccess(self, f, o): + async def _visit_coresdk_workflow_commands_QuerySuccess(self, fs, o): if o.HasField("response"): - await self._visit_temporal_api_common_v1_Payload(f, o.response) + await self._visit_temporal_api_common_v1_Payload(fs, o.response) - async def _visit_coresdk_workflow_commands_QueryResult(self, f, o): + async def _visit_coresdk_workflow_commands_QueryResult(self, fs, o): if o.HasField("succeeded"): - await self._visit_coresdk_workflow_commands_QuerySuccess(f, o.succeeded) + await self._visit_coresdk_workflow_commands_QuerySuccess(fs, o.succeeded) if o.HasField("failed"): - await self._visit_temporal_api_failure_v1_Failure(f, o.failed) + await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) - async def _visit_coresdk_workflow_commands_CompleteWorkflowExecution(self, f, o): + async def _visit_coresdk_workflow_commands_CompleteWorkflowExecution(self, fs, o): if o.HasField("result"): - await self._visit_temporal_api_common_v1_Payload(f, o.result) + await self._visit_temporal_api_common_v1_Payload(fs, o.result) - async def _visit_coresdk_workflow_commands_FailWorkflowExecution(self, f, o): + async def _visit_coresdk_workflow_commands_FailWorkflowExecution(self, fs, o): if o.HasField("failure"): - await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( - self, f, o + self, fs, o ): - for v in o.arguments: - await 
self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_payload_container(fs, o.arguments) for v in o.memo.values(): - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_temporal_api_common_v1_Payload(fs, v) if not self.skip_headers: for v in o.headers.values(): - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_temporal_api_common_v1_Payload(fs, v) for v in o.search_attributes.values(): - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, f, o): - for v in o.input: - await self._visit_temporal_api_common_v1_Payload(f, v) + async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, o): + await self._visit_payload_container(fs, o.input) if not self.skip_headers: for v in o.headers.values(): - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_temporal_api_common_v1_Payload(fs, v) for v in o.memo.values(): - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_temporal_api_common_v1_Payload(fs, v) for v in o.search_attributes.values(): - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_temporal_api_common_v1_Payload(fs, v) async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( - self, f, o + self, fs, o ): - for v in o.args: - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_payload_container(fs, o.args) if not self.skip_headers: for v in o.headers.values(): - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_commands_ScheduleLocalActivity(self, f, o): + async def _visit_coresdk_workflow_commands_ScheduleLocalActivity(self, fs, o): if not self.skip_headers: for v in o.headers.values(): - await self._visit_temporal_api_common_v1_Payload(f, v) - for v in o.arguments: - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_temporal_api_common_v1_Payload(fs, v) + await self._visit_payload_container(fs, o.arguments) async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( - self, f, o + self, fs, o ): for v in o.search_attributes.values(): - await self._visit_temporal_api_common_v1_Payload(f, v) + await self._visit_temporal_api_common_v1_Payload(fs, v) - async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, f, o): + async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, fs, o): if o.HasField("upserted_memo"): - await self._visit_temporal_api_common_v1_Memo(f, o.upserted_memo) + await self._visit_temporal_api_common_v1_Memo(fs, o.upserted_memo) - async def _visit_coresdk_workflow_commands_UpdateResponse(self, f, o): + async def _visit_coresdk_workflow_commands_UpdateResponse(self, fs, o): if o.HasField("rejected"): - await self._visit_temporal_api_failure_v1_Failure(f, o.rejected) + await self._visit_temporal_api_failure_v1_Failure(fs, o.rejected) if o.HasField("completed"): - await self._visit_temporal_api_common_v1_Payload(f, o.completed) + await self._visit_temporal_api_common_v1_Payload(fs, o.completed) - async def _visit_coresdk_workflow_commands_ScheduleNexusOperation(self, f, o): + async def _visit_coresdk_workflow_commands_ScheduleNexusOperation(self, fs, o): if o.HasField("input"): - await self._visit_temporal_api_common_v1_Payload(f, o.input) + await 
self._visit_temporal_api_common_v1_Payload(fs, o.input) - async def _visit_coresdk_workflow_commands_WorkflowCommand(self, f, o): + async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o): if o.HasField("user_metadata"): - await self._visit_temporal_api_sdk_v1_UserMetadata(f, o.user_metadata) + await self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata) if o.HasField("schedule_activity"): await self._visit_coresdk_workflow_commands_ScheduleActivity( - f, o.schedule_activity + fs, o.schedule_activity ) if o.HasField("respond_to_query"): await self._visit_coresdk_workflow_commands_QueryResult( - f, o.respond_to_query + fs, o.respond_to_query ) if o.HasField("complete_workflow_execution"): await self._visit_coresdk_workflow_commands_CompleteWorkflowExecution( - f, o.complete_workflow_execution + fs, o.complete_workflow_execution ) if o.HasField("fail_workflow_execution"): await self._visit_coresdk_workflow_commands_FailWorkflowExecution( - f, o.fail_workflow_execution + fs, o.fail_workflow_execution ) if o.HasField("continue_as_new_workflow_execution"): await self._visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( - f, o.continue_as_new_workflow_execution + fs, o.continue_as_new_workflow_execution ) if o.HasField("start_child_workflow_execution"): await self._visit_coresdk_workflow_commands_StartChildWorkflowExecution( - f, o.start_child_workflow_execution + fs, o.start_child_workflow_execution ) if o.HasField("signal_external_workflow_execution"): await self._visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( - f, o.signal_external_workflow_execution + fs, o.signal_external_workflow_execution ) if o.HasField("schedule_local_activity"): await self._visit_coresdk_workflow_commands_ScheduleLocalActivity( - f, o.schedule_local_activity + fs, o.schedule_local_activity ) if o.HasField("upsert_workflow_search_attributes"): await self._visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( - f, o.upsert_workflow_search_attributes + fs, o.upsert_workflow_search_attributes ) if o.HasField("modify_workflow_properties"): await self._visit_coresdk_workflow_commands_ModifyWorkflowProperties( - f, o.modify_workflow_properties + fs, o.modify_workflow_properties ) if o.HasField("update_response"): await self._visit_coresdk_workflow_commands_UpdateResponse( - f, o.update_response + fs, o.update_response ) if o.HasField("schedule_nexus_operation"): await self._visit_coresdk_workflow_commands_ScheduleNexusOperation( - f, o.schedule_nexus_operation + fs, o.schedule_nexus_operation ) - async def _visit_coresdk_workflow_completion_Success(self, f, o): + async def _visit_coresdk_workflow_completion_Success(self, fs, o): for v in o.commands: - await self._visit_coresdk_workflow_commands_WorkflowCommand(f, v) + await self._visit_coresdk_workflow_commands_WorkflowCommand(fs, v) - async def _visit_coresdk_workflow_completion_Failure(self, f, o): + async def _visit_coresdk_workflow_completion_Failure(self, fs, o): if o.HasField("failure"): - await self._visit_temporal_api_failure_v1_Failure(f, o.failure) + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) async def _visit_coresdk_workflow_completion_WorkflowActivationCompletion( - self, f, o + self, fs, o ): if o.HasField("successful"): - await self._visit_coresdk_workflow_completion_Success(f, o.successful) + await self._visit_coresdk_workflow_completion_Success(fs, o.successful) if o.HasField("failed"): - await self._visit_coresdk_workflow_completion_Failure(f, o.failed) + await 
self._visit_coresdk_workflow_completion_Failure(fs, o.failed) diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 6b0b79e3e..6ce87dbd7 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -12,6 +12,7 @@ Callable, List, Mapping, + MutableSequence, Optional, Sequence, Set, @@ -34,8 +35,8 @@ import temporalio.bridge.temporal_sdk_bridge import temporalio.converter import temporalio.exceptions -from temporalio.api.common.v1.message_pb2 import Payload -from temporalio.bridge._visitor import PayloadVisitor +from temporalio.api.common.v1.message_pb2 import Payload, Payloads +from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, ) @@ -277,19 +278,33 @@ async def finalize_shutdown(self) -> None: await ref.finalize_shutdown() +class _Visitor(VisitorFunctions): + def __init__(self, f: Callable[[Sequence[Payload]], Awaitable[List[Payload]]]): + self._f = f + + async def visit_payload(self, payload: Payload) -> None: + new_payload = (await self._f([payload]))[0] + payload.CopyFrom(new_payload) + + async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + if len(payloads) == 0: + return + new_payloads = await self._f(payloads) + if new_payloads is payloads: + return + del payloads[:] + payloads.extend(new_payloads) + + async def decode_activation( act: temporalio.bridge.proto.workflow_activation.WorkflowActivation, codec: temporalio.converter.PayloadCodec, decode_headers: bool, ) -> None: """Decode the given activation with the codec.""" - - async def visitor(payload: Payload) -> Payload: - return (await codec.decode([payload]))[0] - await PayloadVisitor( skip_search_attributes=True, skip_headers=not decode_headers - ).visit(visitor, act) + ).visit(_Visitor(codec.decode), act) async def encode_completion( @@ -298,10 +313,6 @@ async def encode_completion( encode_headers: bool, ) -> None: """Recursively encode the given completion with the codec.""" - - async def visitor(payload: Payload) -> Payload: - return (await codec.encode([payload]))[0] - await PayloadVisitor( skip_search_attributes=True, skip_headers=not encode_headers - ).visit(visitor, comp) + ).visit(_Visitor(codec.encode), comp) diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py index b36732f36..c59a0248b 100644 --- a/tests/worker/test_visitor.py +++ b/tests/worker/test_visitor.py @@ -1,5 +1,8 @@ +from typing import MutableSequence + from google.protobuf.duration_pb2 import Duration +import temporalio.bridge.worker from temporalio.api.common.v1.message_pb2 import ( Payload, Payloads, @@ -7,7 +10,7 @@ SearchAttributes, ) from temporalio.api.sdk.v1.user_metadata_pb2 import UserMetadata -from temporalio.bridge._visitor import PayloadVisitor +from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( InitializeWorkflow, WorkflowActivation, @@ -26,6 +29,16 @@ Success, WorkflowActivationCompletion, ) +from tests.worker.test_workflow import SimpleCodec + + +class Visitor(VisitorFunctions): + async def visit_payload(self, payload: Payload) -> None: + payload.metadata["visited"] = b"True" + + async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + for payload in payloads: + payload.metadata["visited"] = b"True" async def test_workflow_activation_completion(): @@ -50,21 +63,14 @@ async def test_workflow_activation_completion(): ), ) 
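
The _Visitor adapter above is the reason VisitorFunctions has two methods: whole repeated-payload fields reach the codec in a single visit_payloads call, while singular payload fields arrive one at a time through visit_payload. A rough sketch of observing that batching, assuming the module paths from this series are importable; the BatchSizeRecorder class is hypothetical and only for illustration:

import asyncio
from typing import MutableSequence

from temporalio.api.common.v1.message_pb2 import Payload
from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions
from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import (
    InitializeWorkflow,
    WorkflowActivation,
    WorkflowActivationJob,
)


class BatchSizeRecorder(VisitorFunctions):
    def __init__(self) -> None:
        self.batch_sizes: list[int] = []

    async def visit_payload(self, payload: Payload) -> None:
        self.batch_sizes.append(1)

    async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
        self.batch_sizes.append(len(payloads))


async def main() -> None:
    act = WorkflowActivation(
        jobs=[
            WorkflowActivationJob(
                initialize_workflow=InitializeWorkflow(
                    arguments=[Payload(data=b"a"), Payload(data=b"b")],
                )
            )
        ]
    )
    recorder = BatchSizeRecorder()
    await PayloadVisitor().visit(recorder, act)
    # Both workflow arguments arrive together as a single batch of size 2.
    assert 2 in recorder.batch_sizes


asyncio.run(main())
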
- async def visitor(payload: Payload) -> Payload: - # Mark visited by prefixing data - new_payload = Payload() - new_payload.metadata.update(payload.metadata) - new_payload.data = b"visited:" + payload.data - return new_payload - - await PayloadVisitor().visit(visitor, comp) + await PayloadVisitor().visit(Visitor(), comp) cmd = comp.successful.commands[0] sa = cmd.schedule_activity - assert sa.headers["foo"].data == b"visited:bar" - assert len(sa.arguments) == 1 and sa.arguments[0].data == b"visited:baz" + assert sa.headers["foo"].metadata["visited"] + assert len(sa.arguments) == 1 and sa.arguments[0].metadata["visited"] - assert cmd.user_metadata.summary.data == b"visited:Summary" + assert cmd.user_metadata.summary.metadata["visited"] async def test_workflow_activation(): @@ -102,7 +108,7 @@ async def visitor(payload: Payload) -> Payload: return new_payload act = original.__deepcopy__() - await PayloadVisitor().visit(visitor, act) + await PayloadVisitor().visit(Visitor(), act) assert act.jobs[0].initialize_workflow.arguments[0].metadata["visited"] assert act.jobs[0].initialize_workflow.arguments[1].metadata["visited"] assert act.jobs[0].initialize_workflow.headers["header"].metadata["visited"] @@ -123,7 +129,7 @@ async def visitor(payload: Payload) -> Payload: ) act = original.__deepcopy__() - await PayloadVisitor(skip_search_attributes=True).visit(visitor, act) + await PayloadVisitor(skip_search_attributes=True).visit(Visitor(), act) assert ( not act.jobs[0] .initialize_workflow.search_attributes.indexed_fields["sakey"] @@ -131,7 +137,7 @@ async def visitor(payload: Payload) -> Payload: ) act = original.__deepcopy__() - await PayloadVisitor(skip_headers=True).visit(visitor, act) + await PayloadVisitor(skip_headers=True).visit(Visitor(), act) assert not act.jobs[0].initialize_workflow.headers["header"].metadata["visited"] @@ -180,58 +186,62 @@ async def test_visit_payloads_on_other_commands(): ), ) - async def visitor(payload: Payload) -> Payload: - new_payload = Payload() - new_payload.metadata.update(payload.metadata) - new_payload.data = b"visited:" + payload.data - return new_payload - - await PayloadVisitor().visit(visitor, comp) + await PayloadVisitor().visit(Visitor(), comp) cmds = comp.successful.commands can = cmds[0].continue_as_new_workflow_execution - assert can.arguments[0].data == b"visited:a1" - assert can.headers["h1"].data == b"visited:a2" - assert can.memo["m1"].data == b"visited:a3" + assert can.arguments[0].metadata["visited"] + assert can.headers["h1"].metadata["visited"] + assert can.memo["m1"].metadata["visited"] sc = cmds[1].start_child_workflow_execution - assert sc.input[0].data == b"visited:b1" - assert sc.headers["h2"].data == b"visited:b2" - assert sc.memo["m2"].data == b"visited:b3" + assert sc.input[0].metadata["visited"] + assert sc.headers["h2"].metadata["visited"] + assert sc.memo["m2"].metadata["visited"] se = cmds[2].signal_external_workflow_execution - assert se.args[0].data == b"visited:c1" - assert se.headers["h3"].data == b"visited:c2" + assert se.args[0].metadata["visited"] + assert se.headers["h3"].metadata["visited"] sla = cmds[3].schedule_local_activity - assert sla.arguments[0].data == b"visited:d1" - assert sla.headers["h4"].data == b"visited:d2" + assert sla.arguments[0].metadata["visited"] + assert sla.headers["h4"].metadata["visited"] ur = cmds[4].update_response - assert ur.completed.data == b"visited:e1" + assert ur.completed.metadata["visited"] -async def test_code_gen(): - # Smoke test the generated visitor on a simple activation 
containing payloads - act = WorkflowActivation( - jobs=[ - WorkflowActivationJob( - initialize_workflow=InitializeWorkflow( - arguments=[Payload(data=b"x1"), Payload(data=b"x2")], - headers={"h": Payload(data=b"x3")}, +async def test_bridge_encoding(): + comp = WorkflowActivationCompletion( + run_id="1", + successful=Success( + commands=[ + WorkflowCommand( + schedule_activity=ScheduleActivity( + seq=1, + activity_id="1", + activity_type="", + task_queue="", + headers={"foo": Payload(data=b"bar")}, + arguments=[ + Payload(data=b"repeated1"), + Payload(data=b"repeated2"), + ], + schedule_to_close_timeout=Duration(seconds=5), + priority=Priority(), + ), + user_metadata=UserMetadata(summary=Payload(data=b"Summary")), ) - ) - ] + ], + ), ) - async def _f(p: Payload) -> Payload: - q = Payload() - q.metadata.update(p.metadata) - q.data = b"v:" + p.data - return q - - await PayloadVisitor().visit(_f, act) - init = act.jobs[0].initialize_workflow - assert init.arguments[0].data == b"v:x1" - assert init.arguments[1].data == b"v:x2" - assert init.headers["h"].data == b"v:x3" + await temporalio.bridge.worker.encode_completion(comp, SimpleCodec(), True) + + cmd = comp.successful.commands[0] + sa = cmd.schedule_activity + assert sa.headers["foo"].metadata["simple-codec"] + assert len(sa.arguments) == 1 + assert sa.arguments[0].metadata["simple-codec"] + + assert cmd.user_metadata.summary.metadata["simple-codec"] From cc83c07f62ae0fbd6c7a00fa592f4b8134a8a5eb Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 3 Sep 2025 11:25:03 -0700 Subject: [PATCH 09/14] Optimize codec visitor --- temporalio/bridge/worker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index 6ce87dbd7..9b2abed8e 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -284,7 +284,8 @@ def __init__(self, f: Callable[[Sequence[Payload]], Awaitable[List[Payload]]]): async def visit_payload(self, payload: Payload) -> None: new_payload = (await self._f([payload]))[0] - payload.CopyFrom(new_payload) + if new_payload is not payload: + payload.CopyFrom(new_payload) async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: if len(payloads) == 0: From bc1ad6ae5a9921b568954d0e990b023d39fb11e1 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Thu, 4 Sep 2025 08:37:15 -0700 Subject: [PATCH 10/14] Add warning about encode_failure --- temporalio/converter.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index 190fda0e6..a9f8c0c98 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -689,11 +689,17 @@ async def decode_wrapper(self, payloads: temporalio.api.common.v1.Payloads) -> N payloads.payloads.extend(new_payloads) async def encode_failure(self, failure: temporalio.api.failure.v1.Failure) -> None: - """Encode payloads of a failure.""" + """Encode payloads of a failure. Intended as a helper method, not for overriding. + It is not guaranteed that all failures will be encoded with this method rather + than encoding the underlying payloads. + """ await self._apply_to_failure_payloads(failure, self.encode_wrapper) async def decode_failure(self, failure: temporalio.api.failure.v1.Failure) -> None: - """Decode payloads of a failure.""" + """Decode payloads of a failure. Intended as a helper method, not for overriding. + It is not guaranteed that all failures will be decoded with this method rather + than decoding the underlying payloads. 
+ """ await self._apply_to_failure_payloads(failure, self.decode_wrapper) async def _apply_to_failure_payloads( From 3c0aa254233acd2aec0782e276cc05bffde66c31 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 9 Sep 2025 10:07:33 -0700 Subject: [PATCH 11/14] Update to use elif on oneof fields --- scripts/gen_visitors.py | 33 +++++++++++++++++ temporalio/bridge/_visitor.py | 68 +++++++++++++++++------------------ 2 files changed, 67 insertions(+), 34 deletions(-) diff --git a/scripts/gen_visitors.py b/scripts/gen_visitors.py index 7f8eb6e68..7fd363e8f 100644 --- a/scripts/gen_visitors.py +++ b/scripts/gen_visitors.py @@ -113,10 +113,25 @@ def walk(self, desc: Descriptor) -> bool: lines.append(" if self.skip_search_attributes:") lines.append(" return") + # Group fields by oneof to generate if/elif chains + oneof_fields: dict[int, list[FieldDescriptor]] = {} + regular_fields: list[FieldDescriptor] = [] + for field in desc.fields: if field.type != FieldDescriptor.TYPE_MESSAGE: continue + + # Skip synthetic oneofs (proto3 optional fields) + if field.containing_oneof is not None: + oneof_idx = field.containing_oneof.index + if oneof_idx not in oneof_fields: + oneof_fields[oneof_idx] = [] + oneof_fields[oneof_idx].append(field) + else: + regular_fields.append(field) + # Process regular fields first + for field in regular_fields: # Repeated fields (including maps which are represented as repeated messages) if field.label == FieldDescriptor.LABEL_REPEATED: if ( @@ -174,6 +189,24 @@ def walk(self, desc: Descriptor) -> bool: ) ) + # Process oneof fields as if/elif chains + for oneof_idx, fields in oneof_fields.items(): + oneof_lines = [] + first = True + for field in fields: + child_desc = field.message_type + child_needed = self.walk(child_desc) + needed |= child_needed + if child_needed: + if_word = "if" if first else "elif" + first = False + line = emit_singular( + field.name, f"o.{field.name}", name_for(child_desc), True + ).replace(" if", f" {if_word}", 1) + oneof_lines.append(line) + if oneof_lines: + lines.extend(oneof_lines) + self.generated[key] = needed self.in_progress.discard(key) if needed: diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py index e99087ecf..2db307f2a 100644 --- a/temporalio/bridge/_visitor.py +++ b/temporalio/bridge/_visitor.py @@ -73,15 +73,15 @@ async def _visit_temporal_api_failure_v1_Failure(self, fs, o): await self._visit_temporal_api_failure_v1_ApplicationFailureInfo( fs, o.application_failure_info ) - if o.HasField("timeout_failure_info"): + elif o.HasField("timeout_failure_info"): await self._visit_temporal_api_failure_v1_TimeoutFailureInfo( fs, o.timeout_failure_info ) - if o.HasField("canceled_failure_info"): + elif o.HasField("canceled_failure_info"): await self._visit_temporal_api_failure_v1_CanceledFailureInfo( fs, o.canceled_failure_info ) - if o.HasField("reset_workflow_failure_info"): + elif o.HasField("reset_workflow_failure_info"): await self._visit_temporal_api_failure_v1_ResetWorkflowFailureInfo( fs, o.reset_workflow_failure_info ) @@ -141,9 +141,9 @@ async def _visit_coresdk_activity_result_Cancellation(self, fs, o): async def _visit_coresdk_activity_result_ActivityResolution(self, fs, o): if o.HasField("completed"): await self._visit_coresdk_activity_result_Success(fs, o.completed) - if o.HasField("failed"): + elif o.HasField("failed"): await self._visit_coresdk_activity_result_Failure(fs, o.failed) - if o.HasField("cancelled"): + elif o.HasField("cancelled"): await self._visit_coresdk_activity_result_Cancellation(fs, 
o.cancelled) async def _visit_coresdk_workflow_activation_ResolveActivity(self, fs, o): @@ -179,9 +179,9 @@ async def _visit_coresdk_child_workflow_Cancellation(self, fs, o): async def _visit_coresdk_child_workflow_ChildWorkflowResult(self, fs, o): if o.HasField("completed"): await self._visit_coresdk_child_workflow_Success(fs, o.completed) - if o.HasField("failed"): + elif o.HasField("failed"): await self._visit_coresdk_child_workflow_Failure(fs, o.failed) - if o.HasField("cancelled"): + elif o.HasField("cancelled"): await self._visit_coresdk_child_workflow_Cancellation(fs, o.cancelled) async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( @@ -217,11 +217,11 @@ async def _visit_coresdk_workflow_activation_ResolveNexusOperationStart( async def _visit_coresdk_nexus_NexusOperationResult(self, fs, o): if o.HasField("completed"): await self._visit_temporal_api_common_v1_Payload(fs, o.completed) - if o.HasField("failed"): + elif o.HasField("failed"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) - if o.HasField("cancelled"): + elif o.HasField("cancelled"): await self._visit_temporal_api_failure_v1_Failure(fs, o.cancelled) - if o.HasField("timed_out"): + elif o.HasField("timed_out"): await self._visit_temporal_api_failure_v1_Failure(fs, o.timed_out) async def _visit_coresdk_workflow_activation_ResolveNexusOperation(self, fs, o): @@ -233,41 +233,41 @@ async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, fs, o): await self._visit_coresdk_workflow_activation_InitializeWorkflow( fs, o.initialize_workflow ) - if o.HasField("query_workflow"): + elif o.HasField("query_workflow"): await self._visit_coresdk_workflow_activation_QueryWorkflow( fs, o.query_workflow ) - if o.HasField("signal_workflow"): + elif o.HasField("signal_workflow"): await self._visit_coresdk_workflow_activation_SignalWorkflow( fs, o.signal_workflow ) - if o.HasField("resolve_activity"): + elif o.HasField("resolve_activity"): await self._visit_coresdk_workflow_activation_ResolveActivity( fs, o.resolve_activity ) - if o.HasField("resolve_child_workflow_execution_start"): + elif o.HasField("resolve_child_workflow_execution_start"): await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( fs, o.resolve_child_workflow_execution_start ) - if o.HasField("resolve_child_workflow_execution"): + elif o.HasField("resolve_child_workflow_execution"): await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( fs, o.resolve_child_workflow_execution ) - if o.HasField("resolve_signal_external_workflow"): + elif o.HasField("resolve_signal_external_workflow"): await self._visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( fs, o.resolve_signal_external_workflow ) - if o.HasField("resolve_request_cancel_external_workflow"): + elif o.HasField("resolve_request_cancel_external_workflow"): await self._visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( fs, o.resolve_request_cancel_external_workflow ) - if o.HasField("do_update"): + elif o.HasField("do_update"): await self._visit_coresdk_workflow_activation_DoUpdate(fs, o.do_update) - if o.HasField("resolve_nexus_operation_start"): + elif o.HasField("resolve_nexus_operation_start"): await self._visit_coresdk_workflow_activation_ResolveNexusOperationStart( fs, o.resolve_nexus_operation_start ) - if o.HasField("resolve_nexus_operation"): + elif o.HasField("resolve_nexus_operation"): await self._visit_coresdk_workflow_activation_ResolveNexusOperation( fs, 
o.resolve_nexus_operation ) @@ -295,7 +295,7 @@ async def _visit_coresdk_workflow_commands_QuerySuccess(self, fs, o): async def _visit_coresdk_workflow_commands_QueryResult(self, fs, o): if o.HasField("succeeded"): await self._visit_coresdk_workflow_commands_QuerySuccess(fs, o.succeeded) - if o.HasField("failed"): + elif o.HasField("failed"): await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) async def _visit_coresdk_workflow_commands_CompleteWorkflowExecution(self, fs, o): @@ -355,7 +355,7 @@ async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, fs, o) async def _visit_coresdk_workflow_commands_UpdateResponse(self, fs, o): if o.HasField("rejected"): await self._visit_temporal_api_failure_v1_Failure(fs, o.rejected) - if o.HasField("completed"): + elif o.HasField("completed"): await self._visit_temporal_api_common_v1_Payload(fs, o.completed) async def _visit_coresdk_workflow_commands_ScheduleNexusOperation(self, fs, o): @@ -369,47 +369,47 @@ async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o): await self._visit_coresdk_workflow_commands_ScheduleActivity( fs, o.schedule_activity ) - if o.HasField("respond_to_query"): + elif o.HasField("respond_to_query"): await self._visit_coresdk_workflow_commands_QueryResult( fs, o.respond_to_query ) - if o.HasField("complete_workflow_execution"): + elif o.HasField("complete_workflow_execution"): await self._visit_coresdk_workflow_commands_CompleteWorkflowExecution( fs, o.complete_workflow_execution ) - if o.HasField("fail_workflow_execution"): + elif o.HasField("fail_workflow_execution"): await self._visit_coresdk_workflow_commands_FailWorkflowExecution( fs, o.fail_workflow_execution ) - if o.HasField("continue_as_new_workflow_execution"): + elif o.HasField("continue_as_new_workflow_execution"): await self._visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( fs, o.continue_as_new_workflow_execution ) - if o.HasField("start_child_workflow_execution"): + elif o.HasField("start_child_workflow_execution"): await self._visit_coresdk_workflow_commands_StartChildWorkflowExecution( fs, o.start_child_workflow_execution ) - if o.HasField("signal_external_workflow_execution"): + elif o.HasField("signal_external_workflow_execution"): await self._visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( fs, o.signal_external_workflow_execution ) - if o.HasField("schedule_local_activity"): + elif o.HasField("schedule_local_activity"): await self._visit_coresdk_workflow_commands_ScheduleLocalActivity( fs, o.schedule_local_activity ) - if o.HasField("upsert_workflow_search_attributes"): + elif o.HasField("upsert_workflow_search_attributes"): await self._visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( fs, o.upsert_workflow_search_attributes ) - if o.HasField("modify_workflow_properties"): + elif o.HasField("modify_workflow_properties"): await self._visit_coresdk_workflow_commands_ModifyWorkflowProperties( fs, o.modify_workflow_properties ) - if o.HasField("update_response"): + elif o.HasField("update_response"): await self._visit_coresdk_workflow_commands_UpdateResponse( fs, o.update_response ) - if o.HasField("schedule_nexus_operation"): + elif o.HasField("schedule_nexus_operation"): await self._visit_coresdk_workflow_commands_ScheduleNexusOperation( fs, o.schedule_nexus_operation ) @@ -427,5 +427,5 @@ async def _visit_coresdk_workflow_completion_WorkflowActivationCompletion( ): if o.HasField("successful"): await self._visit_coresdk_workflow_completion_Success(fs, o.successful) - if 
o.HasField("failed"): + elif o.HasField("failed"): await self._visit_coresdk_workflow_completion_Failure(fs, o.failed) From 0f7d0001fa8474ea18a362a7c7bd3092551e6dd9 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 9 Sep 2025 10:16:37 -0700 Subject: [PATCH 12/14] Generate visitors during proto generation --- scripts/gen_protos_docker.py | 2 ++ scripts/gen_visitors.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/gen_protos_docker.py b/scripts/gen_protos_docker.py index 7014022bc..e7738259e 100644 --- a/scripts/gen_protos_docker.py +++ b/scripts/gen_protos_docker.py @@ -24,3 +24,5 @@ check=True, ) subprocess.run(["uv", "run", "poe", "format"], check=True) + +subprocess.run(["uv", "run", f"{os.getcwd()}/scripts/gen_visitors.py"], check=True) diff --git a/scripts/gen_visitors.py b/scripts/gen_visitors.py index 7fd363e8f..d8498df32 100644 --- a/scripts/gen_visitors.py +++ b/scripts/gen_visitors.py @@ -116,11 +116,11 @@ def walk(self, desc: Descriptor) -> bool: # Group fields by oneof to generate if/elif chains oneof_fields: dict[int, list[FieldDescriptor]] = {} regular_fields: list[FieldDescriptor] = [] - + for field in desc.fields: if field.type != FieldDescriptor.TYPE_MESSAGE: continue - + # Skip synthetic oneofs (proto3 optional fields) if field.containing_oneof is not None: oneof_idx = field.containing_oneof.index From f8398eeec0af97e31bf0a2a8d49afa80ffb236cb Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 10 Sep 2025 08:48:29 -0700 Subject: [PATCH 13/14] Use os path join in script --- scripts/gen_protos_docker.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/scripts/gen_protos_docker.py b/scripts/gen_protos_docker.py index e7738259e..6e17be564 100644 --- a/scripts/gen_protos_docker.py +++ b/scripts/gen_protos_docker.py @@ -3,7 +3,14 @@ # Build the Docker image and capture its ID result = subprocess.run( - ["docker", "build", "-q", "-f", "scripts/_proto/Dockerfile", "."], + [ + "docker", + "build", + "-q", + "-f", + os.path.join("scripts", "_proto", "Dockerfile"), + ".", + ], capture_output=True, text=True, check=True, @@ -16,13 +23,15 @@ "run", "--rm", "-v", - f"{os.getcwd()}/temporalio/api:/api_new", + os.path.join(os.getcwd(), "temporalio", "api") + ":/api_new", "-v", - f"{os.getcwd()}/temporalio/bridge/proto:/bridge_new", + os.path.join(os.getcwd(), "temporalio", "bridge", "proto") + ":/bridge_new", image_id, ], check=True, ) subprocess.run(["uv", "run", "poe", "format"], check=True) -subprocess.run(["uv", "run", f"{os.getcwd()}/scripts/gen_visitors.py"], check=True) +subprocess.run( + ["uv", "run", os.path.join(os.getcwd(), "scripts", "gen_visitors.py")], check=True +) From 1a9a2307b484da505881f114f69cd84f39b1edf8 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 10 Sep 2025 11:05:11 -0700 Subject: [PATCH 14/14] PR comments --- ...gen_visitors.py => gen_payload_visitor.py} | 176 +++++++++--------- scripts/gen_protos_docker.py | 3 +- temporalio/bridge/_visitor.py | 9 +- 3 files changed, 100 insertions(+), 88 deletions(-) rename scripts/{gen_visitors.py => gen_payload_visitor.py} (86%) diff --git a/scripts/gen_visitors.py b/scripts/gen_payload_visitor.py similarity index 86% rename from scripts/gen_visitors.py rename to scripts/gen_payload_visitor.py index d8498df32..f35f41d71 100644 --- a/scripts/gen_visitors.py +++ b/scripts/gen_payload_visitor.py @@ -39,18 +39,18 @@ def emit_loop( def emit_singular( - field_name: str, access_expr: str, child_method: str, check_presence: bool + field_name: str, 
access_expr: str, child_method: str, presence_word: Optional[str] ) -> str: # Helper to emit a singular field visit with presence check and optional headers guard - if check_presence: + if presence_word: if field_name == "headers": return f"""\ if not self.skip_headers: - if o.HasField("{field_name}"): + {presence_word} o.HasField("{field_name}"): await self._visit_{child_method}(fs, {access_expr})""" else: return f"""\ - if o.HasField("{field_name}"): + {presence_word} o.HasField("{field_name}"): await self._visit_{child_method}(fs, {access_expr})""" else: if field_name == "headers": @@ -63,6 +63,67 @@ def emit_singular( class VisitorGenerator: + def generate(self, roots: list[Descriptor]) -> str: + """ + Generate Python source code that, given a function f(Payload) -> Payload, + applies it to every Payload contained within a WorkflowActivation tree. + + The generated code defines async visitor functions for each reachable + protobuf message type starting from WorkflowActivation, including support + for repeated fields and map entries, and a convenience entrypoint + function `visit`. + """ + + for r in roots: + self.walk(r) + + header = """ +# This file is generated by gen_payload_visitor.py. Changes should be made there. +import abc +from typing import Any, MutableSequence + +from temporalio.api.common.v1.message_pb2 import Payload + +class VisitorFunctions(abc.ABC): + \"\"\"Set of functions which can be called by the visitor. + Allows handling payloads as a sequence. + \"\"\" + @abc.abstractmethod + async def visit_payload(self, payload: Payload) -> None: + \"\"\"Called when encountering a single payload.\"\"\" + raise NotImplementedError() + + @abc.abstractmethod + async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + \"\"\"Called when encountering multiple payloads together.\"\"\" + raise NotImplementedError() + +class PayloadVisitor: + \"\"\"A visitor for payloads. + Applies a function to every payload in a tree of messages. 
+ \"\"\" + def __init__( + self, *, skip_search_attributes: bool = False, skip_headers: bool = False + ): + \"\"\"Creates a new payload visitor.\"\"\" + self.skip_search_attributes = skip_search_attributes + self.skip_headers = skip_headers + + async def visit( + self, fs: VisitorFunctions, root: Any + ) -> None: + \"\"\"Visits the given root message with the given function.\"\"\" + method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") + method = getattr(self, method_name, None) + if method is not None: + await method(fs, root) + else: + raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") + +""" + + return header + "\n".join(self.methods) + def __init__(self): # Track which message descriptors have visitor methods generated self.generated: dict[str, bool] = { @@ -71,21 +132,24 @@ def __init__(self): } self.in_progress: set[str] = set() self.methods: list[str] = [ - """ async def _visit_temporal_api_common_v1_Payload(self, fs, o): - await fs.visit_payload(o) + """\ + async def _visit_temporal_api_common_v1_Payload(self, fs, o): + await fs.visit_payload(o) """, - """ async def _visit_temporal_api_common_v1_Payloads(self, fs, o): - await fs.visit_payloads(o.payloads) + """\ + async def _visit_temporal_api_common_v1_Payloads(self, fs, o): + await fs.visit_payloads(o.payloads) """, - """ async def _visit_payload_container(self, fs, o): - await fs.visit_payloads(o) + """\ + async def _visit_payload_container(self, fs, o): + await fs.visit_payloads(o) """, ] def check_repeated(self, child_desc, field, iter_expr) -> Optional[str]: # Special case for repeated payloads, handle them directly if child_desc.full_name == Payload.DESCRIPTOR.full_name: - return emit_singular(field.name, iter_expr, "payload_container", False) + return emit_singular(field.name, iter_expr, "payload_container", None) else: child_needed = self.walk(child_desc) if child_needed: @@ -105,7 +169,7 @@ def walk(self, desc: Descriptor) -> bool: # Break cycles; if another path proves this node needed, we'll revisit return False - needed = False + has_payload = False self.in_progress.add(key) lines: list[str] = [f" async def _visit_{name_for(desc)}(self, fs, o):"] # If this is the SearchAttributes message, allow skipping @@ -146,7 +210,7 @@ def walk(self, desc: Descriptor) -> bool: child_desc = val_fd.message_type child_needed = self.walk(child_desc) if child_needed: - needed = True + has_payload = True lines.append( emit_loop( field.name, @@ -163,7 +227,7 @@ def walk(self, desc: Descriptor) -> bool: child_desc = key_fd.message_type child_needed = self.walk(child_desc) if child_needed: - needed = True + has_payload = True lines.append( emit_loop( field.name, @@ -176,16 +240,16 @@ def walk(self, desc: Descriptor) -> bool: field.message_type, field, f"o.{field.name}" ) if child is not None: - needed = True + has_payload = True lines.append(child) else: child_desc = field.message_type - child_needed = self.walk(child_desc) - needed |= child_needed - if child_needed: + child_has_payload = self.walk(child_desc) + has_payload |= child_has_payload + if child_has_payload: lines.append( emit_singular( - field.name, f"o.{field.name}", name_for(child_desc), True + field.name, f"o.{field.name}", name_for(child_desc), "if" ) ) @@ -195,81 +259,23 @@ def walk(self, desc: Descriptor) -> bool: first = True for field in fields: child_desc = field.message_type - child_needed = self.walk(child_desc) - needed |= child_needed - if child_needed: + child_has_payload = self.walk(child_desc) + has_payload |= 
child_has_payload + if child_has_payload: if_word = "if" if first else "elif" first = False line = emit_singular( - field.name, f"o.{field.name}", name_for(child_desc), True - ).replace(" if", f" {if_word}", 1) + field.name, f"o.{field.name}", name_for(child_desc), if_word + ) oneof_lines.append(line) if oneof_lines: lines.extend(oneof_lines) - self.generated[key] = needed + self.generated[key] = has_payload self.in_progress.discard(key) - if needed: + if has_payload: self.methods.append("\n".join(lines) + "\n") - return needed - - def generate(self, roots: list[Descriptor]) -> str: - """ - Generate Python source code that, given a function f(Payload) -> Payload, - applies it to every Payload contained within a WorkflowActivation tree. - - The generated code defines async visitor functions for each reachable - protobuf message type starting from WorkflowActivation, including support - for repeated fields and map entries, and a convenience entrypoint - function `visit`. - """ - - # We avoid importing google.api deps in service protos; expand by walking from - # WorkflowActivationCompletion root which references many command messages. - for r in roots: - self.walk(r) - - header = """ -import abc -from typing import Any, MutableSequence - -from temporalio.api.common.v1.message_pb2 import Payload - -class VisitorFunctions(abc.ABC): - \"\"\"Set of functions which can be called by the visitor. Allows handling payloads as a sequence.\"\"\" - @abc.abstractmethod - async def visit_payload(self, payload: Payload) -> None: - \"\"\"Called when encountering a single payload.\"\"\" - raise NotImplementedError() - - @abc.abstractmethod - async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: - \"\"\"Called when encountering multiple payloads together.\"\"\" - raise NotImplementedError() - -class PayloadVisitor: - \"\"\"A visitor for payloads. Applies a function to every payload in a tree of messages.\"\"\" - def __init__( - self, *, skip_search_attributes: bool = False, skip_headers: bool = False - ): - \"\"\"Creates a new payload visitor.\"\"\" - self.skip_search_attributes = skip_search_attributes - self.skip_headers = skip_headers - - async def visit( - self, fs: VisitorFunctions, root: Any - ) -> None: - \"\"\"Visits the given root message with the given function.\"\"\" - method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") - method = getattr(self, method_name, None) - if method is not None: - await method(fs, root) - else: - raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") - -""" - - return header + "\n".join(self.methods) + return has_payload def write_generated_visitors_into_visitor_generated_py() -> None: diff --git a/scripts/gen_protos_docker.py b/scripts/gen_protos_docker.py index 6e17be564..099c56a2d 100644 --- a/scripts/gen_protos_docker.py +++ b/scripts/gen_protos_docker.py @@ -33,5 +33,6 @@ subprocess.run(["uv", "run", "poe", "format"], check=True) subprocess.run( - ["uv", "run", os.path.join(os.getcwd(), "scripts", "gen_visitors.py")], check=True + ["uv", "run", os.path.join(os.getcwd(), "scripts", "gen_payload_visitor.py")], + check=True, ) diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py index 2db307f2a..c7e38af37 100644 --- a/temporalio/bridge/_visitor.py +++ b/temporalio/bridge/_visitor.py @@ -1,3 +1,4 @@ +# This file is generated by gen_payload_visitor.py. Changes should be made there. 
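
Not part of the patch: a rough sketch of a codec-backed VisitorFunctions implementation, along the lines of the helper added to temporalio/bridge/worker.py earlier in the series. PayloadCodec comes from temporalio.converter; the class name and the encode flag are illustrative only, and the identity check mirrors the "Optimize codec visitor" commit above.

from typing import Awaitable, Callable, List, MutableSequence, Sequence

import temporalio.converter
from temporalio.api.common.v1.message_pb2 import Payload
from temporalio.bridge._visitor import VisitorFunctions


class CodecPayloadVisitor(VisitorFunctions):
    """Adapts a PayloadCodec's batch encode/decode to the visitor interface."""

    def __init__(
        self, codec: temporalio.converter.PayloadCodec, *, encode: bool
    ) -> None:
        # Both encode and decode take Sequence[Payload] and return List[Payload]
        self._f: Callable[[Sequence[Payload]], Awaitable[List[Payload]]] = (
            codec.encode if encode else codec.decode
        )

    async def visit_payload(self, payload: Payload) -> None:
        new_payload = (await self._f([payload]))[0]
        # Skip the copy when the codec returned its input unchanged
        if new_payload is not payload:
            payload.CopyFrom(new_payload)

    async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
        if not payloads:
            return
        new_payloads = await self._f(list(payloads))
        # Replace the contents in place so the parent message sees the new values
        del payloads[:]
        payloads.extend(new_payloads)
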
import abc from typing import Any, MutableSequence @@ -5,7 +6,9 @@ class VisitorFunctions(abc.ABC): - """Set of functions which can be called by the visitor. Allows handling payloads as a sequence.""" + """Set of functions which can be called by the visitor. + Allows handling payloads as a sequence. + """ @abc.abstractmethod async def visit_payload(self, payload: Payload) -> None: @@ -19,7 +22,9 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: class PayloadVisitor: - """A visitor for payloads. Applies a function to every payload in a tree of messages.""" + """A visitor for payloads. + Applies a function to every payload in a tree of messages. + """ def __init__( self, *, skip_search_attributes: bool = False, skip_headers: bool = False
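
To tie the series together, a minimal, self-contained sketch (not part of the patch) of driving the generated visitor over an activation. MarkingVisitor is illustrative; the import paths for the bridge activation protos are assumed from the SDK's existing layout, while PayloadVisitor and VisitorFunctions come from the generated temporalio/bridge/_visitor.py above.

import asyncio
from typing import MutableSequence

from temporalio.api.common.v1.message_pb2 import Payload
from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions
from temporalio.bridge.proto.workflow_activation import (
    InitializeWorkflow,
    WorkflowActivation,
    WorkflowActivationJob,
)


class MarkingVisitor(VisitorFunctions):
    """Tags every payload it sees; stands in for a codec-backed visitor."""

    async def visit_payload(self, payload: Payload) -> None:
        # Mutate in place; the generated visitor does not expect a return value
        payload.metadata["visited"] = b"1"

    async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
        # Batched entry point; a codec-backed implementation could encode
        # the whole sequence in one call here instead of looping
        for p in payloads:
            await self.visit_payload(p)


async def main() -> None:
    act = WorkflowActivation(
        jobs=[
            WorkflowActivationJob(
                initialize_workflow=InitializeWorkflow(
                    arguments=[Payload(data=b"x1")],
                    headers={"h": Payload(data=b"x2")},
                )
            )
        ]
    )
    await PayloadVisitor(skip_headers=True).visit(MarkingVisitor(), act)
    init = act.jobs[0].initialize_workflow
    assert init.arguments[0].metadata["visited"]
    assert "visited" not in init.headers["h"].metadata  # headers were skipped


asyncio.run(main())

With skip_headers=True the generated InitializeWorkflow visitor bypasses the headers map, which is the same behavior tests/worker/test_visitor.py asserts above.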