diff --git a/scripts/gen_payload_visitor.py b/scripts/gen_payload_visitor.py new file mode 100644 index 000000000..f35f41d71 --- /dev/null +++ b/scripts/gen_payload_visitor.py @@ -0,0 +1,299 @@ +import subprocess +import sys +from pathlib import Path +from typing import Optional, Tuple + +from google.protobuf.descriptor import Descriptor, FieldDescriptor + +from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes +from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( + WorkflowActivation, +) +from temporalio.bridge.proto.workflow_completion.workflow_completion_pb2 import ( + WorkflowActivationCompletion, +) + +base_dir = Path(__file__).parent.parent + + +def name_for(desc: Descriptor) -> str: + # Use fully-qualified name to avoid collisions; replace dots with underscores + return desc.full_name.replace(".", "_") + + +def emit_loop( + field_name: str, + iter_expr: str, + child_method: str, +) -> str: + # Helper to emit a for-loop over a collection with optional headers guard + if field_name == "headers": + return f"""\ + if not self.skip_headers: + for v in {iter_expr}: + await self._visit_{child_method}(fs, v)""" + else: + return f"""\ + for v in {iter_expr}: + await self._visit_{child_method}(fs, v)""" + + +def emit_singular( + field_name: str, access_expr: str, child_method: str, presence_word: Optional[str] +) -> str: + # Helper to emit a singular field visit with presence check and optional headers guard + if presence_word: + if field_name == "headers": + return f"""\ + if not self.skip_headers: + {presence_word} o.HasField("{field_name}"): + await self._visit_{child_method}(fs, {access_expr})""" + else: + return f"""\ + {presence_word} o.HasField("{field_name}"): + await self._visit_{child_method}(fs, {access_expr})""" + else: + if field_name == "headers": + return f"""\ + if not self.skip_headers: + await self._visit_{child_method}(fs, {access_expr})""" + else: + return f"""\ + await 
self._visit_{child_method}(fs, {access_expr})""" + + +class VisitorGenerator: + def generate(self, roots: list[Descriptor]) -> str: + """ + Generate Python source code that, given a function f(Payload) -> Payload, + applies it to every Payload contained within a WorkflowActivation tree. + + The generated code defines async visitor functions for each reachable + protobuf message type starting from WorkflowActivation, including support + for repeated fields and map entries, and a convenience entrypoint + function `visit`. + """ + + for r in roots: + self.walk(r) + + header = """ +# This file is generated by gen_payload_visitor.py. Changes should be made there. +import abc +from typing import Any, MutableSequence + +from temporalio.api.common.v1.message_pb2 import Payload + +class VisitorFunctions(abc.ABC): + \"\"\"Set of functions which can be called by the visitor. + Allows handling payloads as a sequence. + \"\"\" + @abc.abstractmethod + async def visit_payload(self, payload: Payload) -> None: + \"\"\"Called when encountering a single payload.\"\"\" + raise NotImplementedError() + + @abc.abstractmethod + async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + \"\"\"Called when encountering multiple payloads together.\"\"\" + raise NotImplementedError() + +class PayloadVisitor: + \"\"\"A visitor for payloads. + Applies a function to every payload in a tree of messages. 
+ \"\"\" + def __init__( + self, *, skip_search_attributes: bool = False, skip_headers: bool = False + ): + \"\"\"Creates a new payload visitor.\"\"\" + self.skip_search_attributes = skip_search_attributes + self.skip_headers = skip_headers + + async def visit( + self, fs: VisitorFunctions, root: Any + ) -> None: + \"\"\"Visits the given root message with the given function.\"\"\" + method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") + method = getattr(self, method_name, None) + if method is not None: + await method(fs, root) + else: + raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") + +""" + + return header + "\n".join(self.methods) + + def __init__(self): + # Track which message descriptors have visitor methods generated + self.generated: dict[str, bool] = { + Payload.DESCRIPTOR.full_name: True, + Payloads.DESCRIPTOR.full_name: True, + } + self.in_progress: set[str] = set() + self.methods: list[str] = [ + """\ + async def _visit_temporal_api_common_v1_Payload(self, fs, o): + await fs.visit_payload(o) + """, + """\ + async def _visit_temporal_api_common_v1_Payloads(self, fs, o): + await fs.visit_payloads(o.payloads) + """, + """\ + async def _visit_payload_container(self, fs, o): + await fs.visit_payloads(o) + """, + ] + + def check_repeated(self, child_desc, field, iter_expr) -> Optional[str]: + # Special case for repeated payloads, handle them directly + if child_desc.full_name == Payload.DESCRIPTOR.full_name: + return emit_singular(field.name, iter_expr, "payload_container", None) + else: + child_needed = self.walk(child_desc) + if child_needed: + return emit_loop( + field.name, + iter_expr, + name_for(child_desc), + ) + else: + return None + + def walk(self, desc: Descriptor) -> bool: + key = desc.full_name + if key in self.generated: + return self.generated[key] + if key in self.in_progress: + # Break cycles; if another path proves this node needed, we'll revisit + return False + + has_payload = False + 
self.in_progress.add(key) + lines: list[str] = [f" async def _visit_{name_for(desc)}(self, fs, o):"] + # If this is the SearchAttributes message, allow skipping + if desc.full_name == SearchAttributes.DESCRIPTOR.full_name: + lines.append(" if self.skip_search_attributes:") + lines.append(" return") + + # Group fields by oneof to generate if/elif chains + oneof_fields: dict[int, list[FieldDescriptor]] = {} + regular_fields: list[FieldDescriptor] = [] + + for field in desc.fields: + if field.type != FieldDescriptor.TYPE_MESSAGE: + continue + + # Skip synthetic oneofs (proto3 optional fields) + if field.containing_oneof is not None: + oneof_idx = field.containing_oneof.index + if oneof_idx not in oneof_fields: + oneof_fields[oneof_idx] = [] + oneof_fields[oneof_idx].append(field) + else: + regular_fields.append(field) + + # Process regular fields first + for field in regular_fields: + # Repeated fields (including maps which are represented as repeated messages) + if field.label == FieldDescriptor.LABEL_REPEATED: + if ( + field.message_type is not None + and field.message_type.GetOptions().map_entry + ): + val_fd = field.message_type.fields_by_name.get("value") + if ( + val_fd is not None + and val_fd.type == FieldDescriptor.TYPE_MESSAGE + ): + child_desc = val_fd.message_type + child_needed = self.walk(child_desc) + if child_needed: + has_payload = True + lines.append( + emit_loop( + field.name, + f"o.{field.name}.values()", + name_for(child_desc), + ) + ) + + key_fd = field.message_type.fields_by_name.get("key") + if ( + key_fd is not None + and key_fd.type == FieldDescriptor.TYPE_MESSAGE + ): + child_desc = key_fd.message_type + child_needed = self.walk(child_desc) + if child_needed: + has_payload = True + lines.append( + emit_loop( + field.name, + f"o.{field.name}.keys()", + name_for(child_desc), + ) + ) + else: + child = self.check_repeated( + field.message_type, field, f"o.{field.name}" + ) + if child is not None: + has_payload = True + lines.append(child) + 
else: + child_desc = field.message_type + child_has_payload = self.walk(child_desc) + has_payload |= child_has_payload + if child_has_payload: + lines.append( + emit_singular( + field.name, f"o.{field.name}", name_for(child_desc), "if" + ) + ) + + # Process oneof fields as if/elif chains + for oneof_idx, fields in oneof_fields.items(): + oneof_lines = [] + first = True + for field in fields: + child_desc = field.message_type + child_has_payload = self.walk(child_desc) + has_payload |= child_has_payload + if child_has_payload: + if_word = "if" if first else "elif" + first = False + line = emit_singular( + field.name, f"o.{field.name}", name_for(child_desc), if_word + ) + oneof_lines.append(line) + if oneof_lines: + lines.extend(oneof_lines) + + self.generated[key] = has_payload + self.in_progress.discard(key) + if has_payload: + self.methods.append("\n".join(lines) + "\n") + return has_payload + + +def write_generated_visitors_into_visitor_generated_py() -> None: + """Write the generated visitor code into _visitor.py.""" + out_path = base_dir / "temporalio" / "bridge" / "_visitor.py" + + # Build root descriptors: WorkflowActivation, WorkflowActivationCompletion, + # and all messages from selected API modules + roots: list[Descriptor] = [ + WorkflowActivation.DESCRIPTOR, + WorkflowActivationCompletion.DESCRIPTOR, + ] + + code = VisitorGenerator().generate(roots) + out_path.write_text(code) + + +if __name__ == "__main__": + print("Generating temporalio/bridge/_visitor.py...", file=sys.stderr) + write_generated_visitors_into_visitor_generated_py() + subprocess.run(["uv", "run", "ruff", "format", "temporalio/bridge/_visitor.py"]) diff --git a/scripts/gen_protos.py b/scripts/gen_protos.py index 61d2709fe..32c5c3af2 100644 --- a/scripts/gen_protos.py +++ b/scripts/gen_protos.py @@ -6,7 +6,7 @@ import tempfile from functools import partial from pathlib import Path -from typing import List, Mapping, Optional +from typing import List, Mapping base_dir = 
Path(__file__).parent.parent proto_dir = ( diff --git a/scripts/gen_protos_docker.py b/scripts/gen_protos_docker.py index 7014022bc..099c56a2d 100644 --- a/scripts/gen_protos_docker.py +++ b/scripts/gen_protos_docker.py @@ -3,7 +3,14 @@ # Build the Docker image and capture its ID result = subprocess.run( - ["docker", "build", "-q", "-f", "scripts/_proto/Dockerfile", "."], + [ + "docker", + "build", + "-q", + "-f", + os.path.join("scripts", "_proto", "Dockerfile"), + ".", + ], capture_output=True, text=True, check=True, @@ -16,11 +23,16 @@ "run", "--rm", "-v", - f"{os.getcwd()}/temporalio/api:/api_new", + os.path.join(os.getcwd(), "temporalio", "api") + ":/api_new", "-v", - f"{os.getcwd()}/temporalio/bridge/proto:/bridge_new", + os.path.join(os.getcwd(), "temporalio", "bridge", "proto") + ":/bridge_new", image_id, ], check=True, ) subprocess.run(["uv", "run", "poe", "format"], check=True) + +subprocess.run( + ["uv", "run", os.path.join(os.getcwd(), "scripts", "gen_payload_visitor.py")], + check=True, +) diff --git a/temporalio/bridge/_visitor.py b/temporalio/bridge/_visitor.py new file mode 100644 index 000000000..c7e38af37 --- /dev/null +++ b/temporalio/bridge/_visitor.py @@ -0,0 +1,436 @@ +# This file is generated by gen_payload_visitor.py. Changes should be made there. +import abc +from typing import Any, MutableSequence + +from temporalio.api.common.v1.message_pb2 import Payload + + +class VisitorFunctions(abc.ABC): + """Set of functions which can be called by the visitor. + Allows handling payloads as a sequence. + """ + + @abc.abstractmethod + async def visit_payload(self, payload: Payload) -> None: + """Called when encountering a single payload.""" + raise NotImplementedError() + + @abc.abstractmethod + async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + """Called when encountering multiple payloads together.""" + raise NotImplementedError() + + +class PayloadVisitor: + """A visitor for payloads. 
+ Applies a function to every payload in a tree of messages. + """ + + def __init__( + self, *, skip_search_attributes: bool = False, skip_headers: bool = False + ): + """Creates a new payload visitor.""" + self.skip_search_attributes = skip_search_attributes + self.skip_headers = skip_headers + + async def visit(self, fs: VisitorFunctions, root: Any) -> None: + """Visits the given root message with the given function.""" + method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_") + method = getattr(self, method_name, None) + if method is not None: + await method(fs, root) + else: + raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}") + + async def _visit_temporal_api_common_v1_Payload(self, fs, o): + await fs.visit_payload(o) + + async def _visit_temporal_api_common_v1_Payloads(self, fs, o): + await fs.visit_payloads(o.payloads) + + async def _visit_payload_container(self, fs, o): + await fs.visit_payloads(o) + + async def _visit_temporal_api_failure_v1_ApplicationFailureInfo(self, fs, o): + if o.HasField("details"): + await self._visit_temporal_api_common_v1_Payloads(fs, o.details) + + async def _visit_temporal_api_failure_v1_TimeoutFailureInfo(self, fs, o): + if o.HasField("last_heartbeat_details"): + await self._visit_temporal_api_common_v1_Payloads( + fs, o.last_heartbeat_details + ) + + async def _visit_temporal_api_failure_v1_CanceledFailureInfo(self, fs, o): + if o.HasField("details"): + await self._visit_temporal_api_common_v1_Payloads(fs, o.details) + + async def _visit_temporal_api_failure_v1_ResetWorkflowFailureInfo(self, fs, o): + if o.HasField("last_heartbeat_details"): + await self._visit_temporal_api_common_v1_Payloads( + fs, o.last_heartbeat_details + ) + + async def _visit_temporal_api_failure_v1_Failure(self, fs, o): + if o.HasField("encoded_attributes"): + await self._visit_temporal_api_common_v1_Payload(fs, o.encoded_attributes) + if o.HasField("application_failure_info"): + await 
self._visit_temporal_api_failure_v1_ApplicationFailureInfo( + fs, o.application_failure_info + ) + elif o.HasField("timeout_failure_info"): + await self._visit_temporal_api_failure_v1_TimeoutFailureInfo( + fs, o.timeout_failure_info + ) + elif o.HasField("canceled_failure_info"): + await self._visit_temporal_api_failure_v1_CanceledFailureInfo( + fs, o.canceled_failure_info + ) + elif o.HasField("reset_workflow_failure_info"): + await self._visit_temporal_api_failure_v1_ResetWorkflowFailureInfo( + fs, o.reset_workflow_failure_info + ) + + async def _visit_temporal_api_common_v1_Memo(self, fs, o): + for v in o.fields.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_temporal_api_common_v1_SearchAttributes(self, fs, o): + if self.skip_search_attributes: + return + for v in o.indexed_fields.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_coresdk_workflow_activation_InitializeWorkflow(self, fs, o): + await self._visit_payload_container(fs, o.arguments) + if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + if o.HasField("continued_failure"): + await self._visit_temporal_api_failure_v1_Failure(fs, o.continued_failure) + if o.HasField("last_completion_result"): + await self._visit_temporal_api_common_v1_Payloads( + fs, o.last_completion_result + ) + if o.HasField("memo"): + await self._visit_temporal_api_common_v1_Memo(fs, o.memo) + if o.HasField("search_attributes"): + await self._visit_temporal_api_common_v1_SearchAttributes( + fs, o.search_attributes + ) + + async def _visit_coresdk_workflow_activation_QueryWorkflow(self, fs, o): + await self._visit_payload_container(fs, o.arguments) + if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_coresdk_workflow_activation_SignalWorkflow(self, fs, o): + await self._visit_payload_container(fs, o.input) + 
if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_coresdk_activity_result_Success(self, fs, o): + if o.HasField("result"): + await self._visit_temporal_api_common_v1_Payload(fs, o.result) + + async def _visit_coresdk_activity_result_Failure(self, fs, o): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) + + async def _visit_coresdk_activity_result_Cancellation(self, fs, o): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) + + async def _visit_coresdk_activity_result_ActivityResolution(self, fs, o): + if o.HasField("completed"): + await self._visit_coresdk_activity_result_Success(fs, o.completed) + elif o.HasField("failed"): + await self._visit_coresdk_activity_result_Failure(fs, o.failed) + elif o.HasField("cancelled"): + await self._visit_coresdk_activity_result_Cancellation(fs, o.cancelled) + + async def _visit_coresdk_workflow_activation_ResolveActivity(self, fs, o): + if o.HasField("result"): + await self._visit_coresdk_activity_result_ActivityResolution(fs, o.result) + + async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( + self, fs, o + ): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) + + async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( + self, fs, o + ): + if o.HasField("cancelled"): + await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStartCancelled( + fs, o.cancelled + ) + + async def _visit_coresdk_child_workflow_Success(self, fs, o): + if o.HasField("result"): + await self._visit_temporal_api_common_v1_Payload(fs, o.result) + + async def _visit_coresdk_child_workflow_Failure(self, fs, o): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) + + async def 
_visit_coresdk_child_workflow_Cancellation(self, fs, o): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) + + async def _visit_coresdk_child_workflow_ChildWorkflowResult(self, fs, o): + if o.HasField("completed"): + await self._visit_coresdk_child_workflow_Success(fs, o.completed) + elif o.HasField("failed"): + await self._visit_coresdk_child_workflow_Failure(fs, o.failed) + elif o.HasField("cancelled"): + await self._visit_coresdk_child_workflow_Cancellation(fs, o.cancelled) + + async def _visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( + self, fs, o + ): + if o.HasField("result"): + await self._visit_coresdk_child_workflow_ChildWorkflowResult(fs, o.result) + + async def _visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( + self, fs, o + ): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) + + async def _visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( + self, fs, o + ): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) + + async def _visit_coresdk_workflow_activation_DoUpdate(self, fs, o): + await self._visit_payload_container(fs, o.input) + if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_coresdk_workflow_activation_ResolveNexusOperationStart( + self, fs, o + ): + if o.HasField("failed"): + await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) + + async def _visit_coresdk_nexus_NexusOperationResult(self, fs, o): + if o.HasField("completed"): + await self._visit_temporal_api_common_v1_Payload(fs, o.completed) + elif o.HasField("failed"): + await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) + elif o.HasField("cancelled"): + await self._visit_temporal_api_failure_v1_Failure(fs, o.cancelled) + elif o.HasField("timed_out"): + await 
self._visit_temporal_api_failure_v1_Failure(fs, o.timed_out) + + async def _visit_coresdk_workflow_activation_ResolveNexusOperation(self, fs, o): + if o.HasField("result"): + await self._visit_coresdk_nexus_NexusOperationResult(fs, o.result) + + async def _visit_coresdk_workflow_activation_WorkflowActivationJob(self, fs, o): + if o.HasField("initialize_workflow"): + await self._visit_coresdk_workflow_activation_InitializeWorkflow( + fs, o.initialize_workflow + ) + elif o.HasField("query_workflow"): + await self._visit_coresdk_workflow_activation_QueryWorkflow( + fs, o.query_workflow + ) + elif o.HasField("signal_workflow"): + await self._visit_coresdk_workflow_activation_SignalWorkflow( + fs, o.signal_workflow + ) + elif o.HasField("resolve_activity"): + await self._visit_coresdk_workflow_activation_ResolveActivity( + fs, o.resolve_activity + ) + elif o.HasField("resolve_child_workflow_execution_start"): + await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecutionStart( + fs, o.resolve_child_workflow_execution_start + ) + elif o.HasField("resolve_child_workflow_execution"): + await self._visit_coresdk_workflow_activation_ResolveChildWorkflowExecution( + fs, o.resolve_child_workflow_execution + ) + elif o.HasField("resolve_signal_external_workflow"): + await self._visit_coresdk_workflow_activation_ResolveSignalExternalWorkflow( + fs, o.resolve_signal_external_workflow + ) + elif o.HasField("resolve_request_cancel_external_workflow"): + await self._visit_coresdk_workflow_activation_ResolveRequestCancelExternalWorkflow( + fs, o.resolve_request_cancel_external_workflow + ) + elif o.HasField("do_update"): + await self._visit_coresdk_workflow_activation_DoUpdate(fs, o.do_update) + elif o.HasField("resolve_nexus_operation_start"): + await self._visit_coresdk_workflow_activation_ResolveNexusOperationStart( + fs, o.resolve_nexus_operation_start + ) + elif o.HasField("resolve_nexus_operation"): + await 
self._visit_coresdk_workflow_activation_ResolveNexusOperation( + fs, o.resolve_nexus_operation + ) + + async def _visit_coresdk_workflow_activation_WorkflowActivation(self, fs, o): + for v in o.jobs: + await self._visit_coresdk_workflow_activation_WorkflowActivationJob(fs, v) + + async def _visit_temporal_api_sdk_v1_UserMetadata(self, fs, o): + if o.HasField("summary"): + await self._visit_temporal_api_common_v1_Payload(fs, o.summary) + if o.HasField("details"): + await self._visit_temporal_api_common_v1_Payload(fs, o.details) + + async def _visit_coresdk_workflow_commands_ScheduleActivity(self, fs, o): + if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + await self._visit_payload_container(fs, o.arguments) + + async def _visit_coresdk_workflow_commands_QuerySuccess(self, fs, o): + if o.HasField("response"): + await self._visit_temporal_api_common_v1_Payload(fs, o.response) + + async def _visit_coresdk_workflow_commands_QueryResult(self, fs, o): + if o.HasField("succeeded"): + await self._visit_coresdk_workflow_commands_QuerySuccess(fs, o.succeeded) + elif o.HasField("failed"): + await self._visit_temporal_api_failure_v1_Failure(fs, o.failed) + + async def _visit_coresdk_workflow_commands_CompleteWorkflowExecution(self, fs, o): + if o.HasField("result"): + await self._visit_temporal_api_common_v1_Payload(fs, o.result) + + async def _visit_coresdk_workflow_commands_FailWorkflowExecution(self, fs, o): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) + + async def _visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( + self, fs, o + ): + await self._visit_payload_container(fs, o.arguments) + for v in o.memo.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + for v in o.search_attributes.values(): + await 
self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_coresdk_workflow_commands_StartChildWorkflowExecution(self, fs, o): + await self._visit_payload_container(fs, o.input) + if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + for v in o.memo.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + for v in o.search_attributes.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( + self, fs, o + ): + await self._visit_payload_container(fs, o.args) + if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_coresdk_workflow_commands_ScheduleLocalActivity(self, fs, o): + if not self.skip_headers: + for v in o.headers.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + await self._visit_payload_container(fs, o.arguments) + + async def _visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( + self, fs, o + ): + for v in o.search_attributes.values(): + await self._visit_temporal_api_common_v1_Payload(fs, v) + + async def _visit_coresdk_workflow_commands_ModifyWorkflowProperties(self, fs, o): + if o.HasField("upserted_memo"): + await self._visit_temporal_api_common_v1_Memo(fs, o.upserted_memo) + + async def _visit_coresdk_workflow_commands_UpdateResponse(self, fs, o): + if o.HasField("rejected"): + await self._visit_temporal_api_failure_v1_Failure(fs, o.rejected) + elif o.HasField("completed"): + await self._visit_temporal_api_common_v1_Payload(fs, o.completed) + + async def _visit_coresdk_workflow_commands_ScheduleNexusOperation(self, fs, o): + if o.HasField("input"): + await self._visit_temporal_api_common_v1_Payload(fs, o.input) + + async def _visit_coresdk_workflow_commands_WorkflowCommand(self, fs, o): + if o.HasField("user_metadata"): + await 
self._visit_temporal_api_sdk_v1_UserMetadata(fs, o.user_metadata) + if o.HasField("schedule_activity"): + await self._visit_coresdk_workflow_commands_ScheduleActivity( + fs, o.schedule_activity + ) + elif o.HasField("respond_to_query"): + await self._visit_coresdk_workflow_commands_QueryResult( + fs, o.respond_to_query + ) + elif o.HasField("complete_workflow_execution"): + await self._visit_coresdk_workflow_commands_CompleteWorkflowExecution( + fs, o.complete_workflow_execution + ) + elif o.HasField("fail_workflow_execution"): + await self._visit_coresdk_workflow_commands_FailWorkflowExecution( + fs, o.fail_workflow_execution + ) + elif o.HasField("continue_as_new_workflow_execution"): + await self._visit_coresdk_workflow_commands_ContinueAsNewWorkflowExecution( + fs, o.continue_as_new_workflow_execution + ) + elif o.HasField("start_child_workflow_execution"): + await self._visit_coresdk_workflow_commands_StartChildWorkflowExecution( + fs, o.start_child_workflow_execution + ) + elif o.HasField("signal_external_workflow_execution"): + await self._visit_coresdk_workflow_commands_SignalExternalWorkflowExecution( + fs, o.signal_external_workflow_execution + ) + elif o.HasField("schedule_local_activity"): + await self._visit_coresdk_workflow_commands_ScheduleLocalActivity( + fs, o.schedule_local_activity + ) + elif o.HasField("upsert_workflow_search_attributes"): + await self._visit_coresdk_workflow_commands_UpsertWorkflowSearchAttributes( + fs, o.upsert_workflow_search_attributes + ) + elif o.HasField("modify_workflow_properties"): + await self._visit_coresdk_workflow_commands_ModifyWorkflowProperties( + fs, o.modify_workflow_properties + ) + elif o.HasField("update_response"): + await self._visit_coresdk_workflow_commands_UpdateResponse( + fs, o.update_response + ) + elif o.HasField("schedule_nexus_operation"): + await self._visit_coresdk_workflow_commands_ScheduleNexusOperation( + fs, o.schedule_nexus_operation + ) + + async def 
_visit_coresdk_workflow_completion_Success(self, fs, o): + for v in o.commands: + await self._visit_coresdk_workflow_commands_WorkflowCommand(fs, v) + + async def _visit_coresdk_workflow_completion_Failure(self, fs, o): + if o.HasField("failure"): + await self._visit_temporal_api_failure_v1_Failure(fs, o.failure) + + async def _visit_coresdk_workflow_completion_WorkflowActivationCompletion( + self, fs, o + ): + if o.HasField("successful"): + await self._visit_coresdk_workflow_completion_Success(fs, o.successful) + elif o.HasField("failed"): + await self._visit_coresdk_workflow_completion_Failure(fs, o.failed) diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index e4cb05eee..9b2abed8e 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -12,6 +12,7 @@ Callable, List, Mapping, + MutableSequence, Optional, Sequence, Set, @@ -34,6 +35,8 @@ import temporalio.bridge.temporal_sdk_bridge import temporalio.converter import temporalio.exceptions +from temporalio.api.common.v1.message_pb2 import Payload, Payloads +from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions from temporalio.bridge.temporal_sdk_bridge import ( CustomSlotSupplier as BridgeCustomSlotSupplier, ) @@ -275,108 +278,23 @@ async def finalize_shutdown(self) -> None: await ref.finalize_shutdown() -# See https://mypy.readthedocs.io/en/stable/runtime_troubles.html#using-classes-that-are-generic-in-stubs-but-not-at-runtime -if TYPE_CHECKING: - PayloadContainer: TypeAlias = ( - google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ - temporalio.api.common.v1.Payload - ] - ) -else: - PayloadContainer: TypeAlias = ( - google.protobuf.internal.containers.RepeatedCompositeFieldContainer - ) - - -async def _apply_to_headers( - headers: Mapping[str, temporalio.api.common.v1.Payload], - cb: Callable[ - [Sequence[temporalio.api.common.v1.Payload]], - Awaitable[List[temporalio.api.common.v1.Payload]], - ], -) -> None: - """Apply API payload 
callback to headers.""" - for payload in headers.values(): - new_payload = (await cb([payload]))[0] - payload.CopyFrom(new_payload) - - -async def _decode_headers( - headers: Mapping[str, temporalio.api.common.v1.Payload], - codec: temporalio.converter.PayloadCodec, -) -> None: - """Decode headers with the given codec.""" - return await _apply_to_headers(headers, codec.decode) - - -async def _encode_headers( - headers: Mapping[str, temporalio.api.common.v1.Payload], - codec: temporalio.converter.PayloadCodec, -) -> None: - """Encode headers with the given codec.""" - return await _apply_to_headers(headers, codec.encode) +class _Visitor(VisitorFunctions): + def __init__(self, f: Callable[[Sequence[Payload]], Awaitable[List[Payload]]]): + self._f = f + async def visit_payload(self, payload: Payload) -> None: + new_payload = (await self._f([payload]))[0] + if new_payload is not payload: + payload.CopyFrom(new_payload) -async def _apply_to_payloads( - payloads: PayloadContainer, - cb: Callable[ - [Sequence[temporalio.api.common.v1.Payload]], - Awaitable[List[temporalio.api.common.v1.Payload]], - ], -) -> None: - """Apply API payload callback to payloads.""" - if len(payloads) == 0: - return - new_payloads = await cb(payloads) - if new_payloads is payloads: - return - del payloads[:] - # TODO(cretz): Copy too expensive? 
- payloads.extend(new_payloads) - - -async def _apply_to_payload( - payload: temporalio.api.common.v1.Payload, - cb: Callable[ - [Sequence[temporalio.api.common.v1.Payload]], - Awaitable[List[temporalio.api.common.v1.Payload]], - ], -) -> None: - """Apply API payload callback to payload.""" - new_payload = (await cb([payload]))[0] - payload.CopyFrom(new_payload) - - -async def _decode_payloads( - payloads: PayloadContainer, - codec: temporalio.converter.PayloadCodec, -) -> None: - """Decode payloads with the given codec.""" - return await _apply_to_payloads(payloads, codec.decode) - - -async def _decode_payload( - payload: temporalio.api.common.v1.Payload, - codec: temporalio.converter.PayloadCodec, -) -> None: - """Decode a payload with the given codec.""" - return await _apply_to_payload(payload, codec.decode) - - -async def _encode_payloads( - payloads: PayloadContainer, - codec: temporalio.converter.PayloadCodec, -) -> None: - """Encode payloads with the given codec.""" - return await _apply_to_payloads(payloads, codec.encode) - - -async def _encode_payload( - payload: temporalio.api.common.v1.Payload, - codec: temporalio.converter.PayloadCodec, -) -> None: - """Decode a payload with the given codec.""" - return await _apply_to_payload(payload, codec.encode) + async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None: + if len(payloads) == 0: + return + new_payloads = await self._f(payloads) + if new_payloads is payloads: + return + del payloads[:] + payloads.extend(new_payloads) async def decode_activation( @@ -385,76 +303,9 @@ async def decode_activation( decode_headers: bool, ) -> None: """Decode the given activation with the codec.""" - for job in act.jobs: - if job.HasField("query_workflow"): - await _decode_payloads(job.query_workflow.arguments, codec) - if decode_headers: - await _decode_headers(job.query_workflow.headers, codec) - elif job.HasField("resolve_activity"): - if job.resolve_activity.result.HasField("cancelled"): - await 
async def decode_activation(
    act: temporalio.bridge.proto.workflow_activation.WorkflowActivation,
    codec: temporalio.converter.PayloadCodec,
    decode_headers: bool,
) -> None:
    """Decode the given activation with the codec."""
    # Search attributes are never decoded; header decoding is controlled by
    # the caller via decode_headers.
    visitor = PayloadVisitor(
        skip_search_attributes=True,
        skip_headers=not decode_headers,
    )
    await visitor.visit(_Visitor(codec.decode), act)
async def encode_completion(
    comp: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion,
    codec: temporalio.converter.PayloadCodec,
    encode_headers: bool,
) -> None:
    """Recursively encode the given completion with the codec."""
    # Search attributes are never encoded; header encoding is controlled by
    # the caller via encode_headers.
    visitor = PayloadVisitor(
        skip_search_attributes=True,
        skip_headers=not encode_headers,
    )
    await visitor.visit(_Visitor(codec.encode), comp)
decode_wrapper(self, payloads: temporalio.api.common.v1.Payloads) -> N payloads.payloads.extend(new_payloads) async def encode_failure(self, failure: temporalio.api.failure.v1.Failure) -> None: - """Encode payloads of a failure.""" + """Encode payloads of a failure. Intended as a helper method, not for overriding. + It is not guaranteed that all failures will be encoded with this method rather + than encoding the underlying payloads. + """ await self._apply_to_failure_payloads(failure, self.encode_wrapper) async def decode_failure(self, failure: temporalio.api.failure.v1.Failure) -> None: - """Decode payloads of a failure.""" + """Decode payloads of a failure. Intended as a helper method, not for overriding. + It is not guaranteed that all failures will be decoded with this method rather + than decoding the underlying payloads. + """ await self._apply_to_failure_payloads(failure, self.decode_wrapper) async def _apply_to_failure_payloads( diff --git a/tests/worker/test_visitor.py b/tests/worker/test_visitor.py new file mode 100644 index 000000000..c59a0248b --- /dev/null +++ b/tests/worker/test_visitor.py @@ -0,0 +1,247 @@ +from typing import MutableSequence + +from google.protobuf.duration_pb2 import Duration + +import temporalio.bridge.worker +from temporalio.api.common.v1.message_pb2 import ( + Payload, + Payloads, + Priority, + SearchAttributes, +) +from temporalio.api.sdk.v1.user_metadata_pb2 import UserMetadata +from temporalio.bridge._visitor import PayloadVisitor, VisitorFunctions +from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( + InitializeWorkflow, + WorkflowActivation, + WorkflowActivationJob, +) +from temporalio.bridge.proto.workflow_commands.workflow_commands_pb2 import ( + ContinueAsNewWorkflowExecution, + ScheduleActivity, + ScheduleLocalActivity, + SignalExternalWorkflowExecution, + StartChildWorkflowExecution, + UpdateResponse, + WorkflowCommand, +) +from 
class Visitor(VisitorFunctions):
    """Test visitor that tags each encountered payload's metadata as visited."""

    async def visit_payload(self, payload: Payload) -> None:
        payload.metadata["visited"] = b"True"

    async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
        # Delegate per-payload so single and batched visits behave identically.
        for payload in payloads:
            await self.visit_payload(payload)


async def test_workflow_activation_completion():
    """Visitor reaches headers, arguments, and user metadata on a completion."""
    activity = ScheduleActivity(
        seq=1,
        activity_id="1",
        activity_type="",
        task_queue="",
        headers={"foo": Payload(data=b"bar")},
        arguments=[Payload(data=b"baz")],
        schedule_to_close_timeout=Duration(seconds=5),
        priority=Priority(),
    )
    command = WorkflowCommand(
        schedule_activity=activity,
        user_metadata=UserMetadata(summary=Payload(data=b"Summary")),
    )
    comp = WorkflowActivationCompletion(
        run_id="1",
        successful=Success(commands=[command]),
    )

    await PayloadVisitor().visit(Visitor(), comp)

    visited = comp.successful.commands[0]
    schedule = visited.schedule_activity
    assert schedule.headers["foo"].metadata["visited"]
    assert len(schedule.arguments) == 1
    assert schedule.arguments[0].metadata["visited"]
    assert visited.user_metadata.summary.metadata["visited"]
async def test_workflow_activation():
    """Visitor reaches every payload in an activation and honors skip flags.

    Builds one activation template, then visits deep copies of it with the
    default visitor and with each skip option enabled.
    """
    import copy

    original = WorkflowActivation(
        jobs=[
            WorkflowActivationJob(
                initialize_workflow=InitializeWorkflow(
                    arguments=[
                        Payload(data=b"repeated1"),
                        Payload(data=b"repeated2"),
                    ],
                    headers={"header": Payload(data=b"map")},
                    last_completion_result=Payloads(
                        payloads=[
                            Payload(data=b"obj1"),
                            Payload(data=b"obj2"),
                        ]
                    ),
                    search_attributes=SearchAttributes(
                        indexed_fields={
                            "sakey": Payload(data=b"saobj"),
                        }
                    ),
                ),
            )
        ]
    )

    # Default visitor reaches arguments, headers, nested Payloads objects,
    # and search attributes.
    act = copy.deepcopy(original)
    await PayloadVisitor().visit(Visitor(), act)
    init = act.jobs[0].initialize_workflow
    assert init.arguments[0].metadata["visited"]
    assert init.arguments[1].metadata["visited"]
    assert init.headers["header"].metadata["visited"]
    assert init.last_completion_result.payloads[0].metadata["visited"]
    assert init.last_completion_result.payloads[1].metadata["visited"]
    assert init.search_attributes.indexed_fields["sakey"].metadata["visited"]

    # skip_search_attributes leaves search attribute payloads untouched.
    act = copy.deepcopy(original)
    await PayloadVisitor(skip_search_attributes=True).visit(Visitor(), act)
    init = act.jobs[0].initialize_workflow
    assert not init.search_attributes.indexed_fields["sakey"].metadata["visited"]

    # skip_headers leaves header payloads untouched.
    act = copy.deepcopy(original)
    await PayloadVisitor(skip_headers=True).visit(Visitor(), act)
    assert not act.jobs[0].initialize_workflow.headers["header"].metadata["visited"]
async def test_visit_payloads_on_other_commands():
    """Visitor reaches payloads on each command variant that carries them."""
    continue_as_new = ContinueAsNewWorkflowExecution(
        arguments=[Payload(data=b"a1")],
        headers={"h1": Payload(data=b"a2")},
        memo={"m1": Payload(data=b"a3")},
    )
    start_child = StartChildWorkflowExecution(
        input=[Payload(data=b"b1")],
        headers={"h2": Payload(data=b"b2")},
        memo={"m2": Payload(data=b"b3")},
    )
    signal_external = SignalExternalWorkflowExecution(
        args=[Payload(data=b"c1")],
        headers={"h3": Payload(data=b"c2")},
    )
    schedule_local = ScheduleLocalActivity(
        arguments=[Payload(data=b"d1")],
        headers={"h4": Payload(data=b"d2")},
    )
    update = UpdateResponse(completed=Payload(data=b"e1"))
    comp = WorkflowActivationCompletion(
        run_id="2",
        successful=Success(
            commands=[
                WorkflowCommand(continue_as_new_workflow_execution=continue_as_new),
                WorkflowCommand(start_child_workflow_execution=start_child),
                WorkflowCommand(signal_external_workflow_execution=signal_external),
                WorkflowCommand(schedule_local_activity=schedule_local),
                WorkflowCommand(update_response=update),
            ]
        ),
    )

    await PayloadVisitor().visit(Visitor(), comp)

    commands = comp.successful.commands

    can = commands[0].continue_as_new_workflow_execution
    assert can.arguments[0].metadata["visited"]
    assert can.headers["h1"].metadata["visited"]
    assert can.memo["m1"].metadata["visited"]

    child = commands[1].start_child_workflow_execution
    assert child.input[0].metadata["visited"]
    assert child.headers["h2"].metadata["visited"]
    assert child.memo["m2"].metadata["visited"]

    signal = commands[2].signal_external_workflow_execution
    assert signal.args[0].metadata["visited"]
    assert signal.headers["h3"].metadata["visited"]

    local = commands[3].schedule_local_activity
    assert local.arguments[0].metadata["visited"]
    assert local.headers["h4"].metadata["visited"]

    assert commands[4].update_response.completed.metadata["visited"]


async def test_bridge_encoding():
    """encode_completion runs the codec over command payloads and user metadata."""
    comp = WorkflowActivationCompletion(
        run_id="1",
        successful=Success(
            commands=[
                WorkflowCommand(
                    schedule_activity=ScheduleActivity(
                        seq=1,
                        activity_id="1",
                        activity_type="",
                        task_queue="",
                        headers={"foo": Payload(data=b"bar")},
                        arguments=[
                            Payload(data=b"repeated1"),
                            Payload(data=b"repeated2"),
                        ],
                        schedule_to_close_timeout=Duration(seconds=5),
                        priority=Priority(),
                    ),
                    user_metadata=UserMetadata(summary=Payload(data=b"Summary")),
                )
            ],
        ),
    )

    await temporalio.bridge.worker.encode_completion(comp, SimpleCodec(), True)

    command = comp.successful.commands[0]
    scheduled = command.schedule_activity
    assert scheduled.headers["foo"].metadata["simple-codec"]
    # The codec batches the two original arguments into a single encoded
    # payload, hence exactly one argument after encoding.
    assert len(scheduled.arguments) == 1
    assert scheduled.arguments[0].metadata["simple-codec"]
    assert command.user_metadata.summary.metadata["simple-codec"]