|
| 1 | +import sys |
| 2 | +from pathlib import Path |
| 3 | + |
| 4 | +from google.protobuf.descriptor import Descriptor, FieldDescriptor |
| 5 | + |
| 6 | +from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes |
| 7 | +from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import ( |
| 8 | + WorkflowActivation, |
| 9 | +) |
| 10 | +from temporalio.bridge.proto.workflow_completion.workflow_completion_pb2 import ( |
| 11 | + WorkflowActivationCompletion, |
| 12 | +) |
| 13 | + |
| 14 | +base_dir = Path(__file__).parent.parent |
| 15 | + |
| 16 | +def gen_workflow_activation_payload_visitor_code() -> str: |
| 17 | + """ |
| 18 | + Generate Python source code that, given a function f(Payload) -> Payload, |
| 19 | + applies it to every Payload contained within a WorkflowActivation tree. |
| 20 | +
|
| 21 | + The generated code defines async visitor functions for each reachable |
| 22 | + protobuf message type starting from WorkflowActivation, including support |
| 23 | + for repeated fields and map entries, and a convenience entrypoint |
| 24 | + function `visit_workflow_activation_payloads`. |
| 25 | + """ |
| 26 | + def name_for(desc: Descriptor) -> str: |
| 27 | + # Use fully-qualified name to avoid collisions; replace dots with underscores |
| 28 | + return desc.full_name.replace('.', '_') |
| 29 | + |
| 30 | + def emit_loop(lines: list[str], field_name: str, iter_expr: str, var_name: str, child_method: str) -> None: |
| 31 | + # Helper to emit a for-loop over a collection with optional headers guard |
| 32 | + if field_name == "headers": |
| 33 | + lines.append(" if not self.skip_headers:") |
| 34 | + lines.append(f" for {var_name} in {iter_expr}:") |
| 35 | + lines.append(f" await self.visit_{child_method}(f, {var_name})") |
| 36 | + else: |
| 37 | + lines.append(f" for {var_name} in {iter_expr}:") |
| 38 | + lines.append(f" await self.visit_{child_method}(f, {var_name})") |
| 39 | + |
| 40 | + def emit_singular(lines: list[str], field_name: str, access_expr: str, child_method: str) -> None: |
| 41 | + # Helper to emit a singular field visit with presence check and optional headers guard |
| 42 | + if field_name == "headers": |
| 43 | + lines.append(" if not self.skip_headers:") |
| 44 | + lines.append(f" if o.HasField('{field_name}'):") |
| 45 | + lines.append(f" await self.visit_{child_method}(f, {access_expr})") |
| 46 | + else: |
| 47 | + lines.append(f" if o.HasField('{field_name}'):") |
| 48 | + lines.append(f" await self.visit_{child_method}(f, {access_expr})") |
| 49 | + |
| 50 | + # Track which message descriptors have visitor methods generated |
| 51 | + generated: dict[str, bool] = {} |
| 52 | + in_progress: set[str] = set() |
| 53 | + methods: list[str] = [] |
| 54 | + |
| 55 | + def walk(desc: Descriptor) -> bool: |
| 56 | + key = desc.full_name |
| 57 | + if key in generated: |
| 58 | + return generated[key] |
| 59 | + if key in in_progress: |
| 60 | + # Break cycles; if another path proves this node needed, we'll revisit |
| 61 | + return False |
| 62 | + |
| 63 | + if desc.full_name == Payload.DESCRIPTOR.full_name: |
| 64 | + generated[key] = True |
| 65 | + methods.append( |
| 66 | + """ async def visit_temporal_api_common_v1_Payload(self, f, o): |
| 67 | + o.CopyFrom(await f(o)) |
| 68 | +""" |
| 69 | + ) |
| 70 | + return True |
| 71 | + |
| 72 | + needed = False |
| 73 | + in_progress.add(key) |
| 74 | + lines: list[str] = [f" async def visit_{name_for(desc)}(self, f, o):"] |
| 75 | + # If this is the SearchAttributes message, allow skipping |
| 76 | + if desc.full_name == SearchAttributes.DESCRIPTOR.full_name: |
| 77 | + lines.append(" if self.skip_search_attributes:") |
| 78 | + lines.append(" return") |
| 79 | + |
| 80 | + for field in desc.fields: |
| 81 | + if field.type != FieldDescriptor.TYPE_MESSAGE: |
| 82 | + continue |
| 83 | + |
| 84 | + # Repeated fields (including maps which are represented as repeated messages) |
| 85 | + if field.label == FieldDescriptor.LABEL_REPEATED: |
| 86 | + if field.message_type is not None and field.message_type.GetOptions().map_entry: |
| 87 | + entry_desc = field.message_type |
| 88 | + key_fd = entry_desc.fields_by_name.get("key") |
| 89 | + val_fd = entry_desc.fields_by_name.get("value") |
| 90 | + |
| 91 | + if val_fd is not None and val_fd.type == FieldDescriptor.TYPE_MESSAGE: |
| 92 | + child_desc = val_fd.message_type |
| 93 | + child_needed = walk(child_desc) |
| 94 | + needed |= child_needed |
| 95 | + if child_needed: |
| 96 | + emit_loop(lines, field.name, f"o.{field.name}.values()", "v", name_for(child_desc)) |
| 97 | + |
| 98 | + if key_fd is not None and key_fd.type == FieldDescriptor.TYPE_MESSAGE: |
| 99 | + key_desc = key_fd.message_type |
| 100 | + child_needed = walk(key_desc) |
| 101 | + needed |= child_needed |
| 102 | + if child_needed: |
| 103 | + emit_loop(lines, field.name, f"o.{field.name}.keys()", "k", name_for(key_desc)) |
| 104 | + else: |
| 105 | + child_desc = field.message_type |
| 106 | + child_needed = walk(child_desc) |
| 107 | + needed |= child_needed |
| 108 | + if child_needed: |
| 109 | + emit_loop(lines, field.name, f"o.{field.name}", "v", name_for(child_desc)) |
| 110 | + else: |
| 111 | + child_desc = field.message_type |
| 112 | + child_needed = walk(child_desc) |
| 113 | + needed |= child_needed |
| 114 | + if child_needed: |
| 115 | + emit_singular(lines, field.name, f"o.{field.name}", name_for(child_desc)) |
| 116 | + |
| 117 | + generated[key] = needed |
| 118 | + in_progress.discard(key) |
| 119 | + if needed: |
| 120 | + methods.append("\n".join(lines) + "\n") |
| 121 | + return needed |
| 122 | + |
| 123 | + # Build root descriptors: WorkflowActivation, WorkflowActivationCompletion, |
| 124 | + # and all messages from selected API modules |
| 125 | + roots: list[Descriptor] = [ |
| 126 | + WorkflowActivation.DESCRIPTOR, |
| 127 | + WorkflowActivationCompletion.DESCRIPTOR, |
| 128 | + ] |
| 129 | + |
| 130 | + # We avoid importing google.api deps in service protos; expand by walking from |
| 131 | + # WorkflowActivationCompletion root which references many command messages. |
| 132 | + |
| 133 | + for r in roots: |
| 134 | + walk(r) |
| 135 | + |
| 136 | + header = ( |
| 137 | + "from typing import Awaitable, Callable, Any\n\n" |
| 138 | + "from temporalio.api.common.v1.message_pb2 import Payload\n\n\n" |
| 139 | + "class PayloadVisitor:\n" |
| 140 | + " def __init__(self, *, skip_search_attributes: bool = False, skip_headers: bool = False):\n" |
| 141 | + " self.skip_search_attributes = skip_search_attributes\n" |
| 142 | + " self.skip_headers = skip_headers\n\n" |
| 143 | + " async def visit(self, f: Callable[[Payload], Awaitable[Payload]], root: Any) -> None:\n" |
| 144 | + " method_name = 'visit_' + root.DESCRIPTOR.full_name.replace('.', '_')\n" |
| 145 | + " method = getattr(self, method_name, None)\n" |
| 146 | + " if method is not None:\n" |
| 147 | + " await method(f, root)\n\n" |
| 148 | + ) |
| 149 | + |
| 150 | + return header + "\n".join(methods) |
| 151 | + |
| 152 | + |
| 153 | +def write_generated_visitors_into_visitor_generated_py() -> None: |
| 154 | + """Write the generated visitor code into visitor_generated.py.""" |
| 155 | + out_path = base_dir / "temporalio" / "bridge" / "visitor_generated.py" |
| 156 | + code = gen_workflow_activation_payload_visitor_code() |
| 157 | + out_path.write_text(code) |
| 158 | + |
| 159 | +if __name__ == "__main__": |
| 160 | + print("Generating temporalio/bridge/visitor_generated.py...", file=sys.stderr) |
| 161 | + write_generated_visitors_into_visitor_generated_py() |
| 162 | + |
0 commit comments