Skip to content
Merged
2 changes: 1 addition & 1 deletion scripts/gen_protos.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tempfile
from functools import partial
from pathlib import Path
from typing import List, Mapping, Optional
from typing import List, Mapping

base_dir = Path(__file__).parent.parent
proto_dir = (
Expand Down
260 changes: 260 additions & 0 deletions scripts/gen_visitors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
import subprocess
import sys
from pathlib import Path
from typing import Optional, Tuple

from google.protobuf.descriptor import Descriptor, FieldDescriptor

from temporalio.api.common.v1.message_pb2 import Payload, Payloads, SearchAttributes
from temporalio.bridge.proto.workflow_activation.workflow_activation_pb2 import (
WorkflowActivation,
)
from temporalio.bridge.proto.workflow_completion.workflow_completion_pb2 import (
WorkflowActivationCompletion,
)

base_dir = Path(__file__).parent.parent


def name_for(desc: Descriptor) -> str:
# Use fully-qualified name to avoid collisions; replace dots with underscores
return desc.full_name.replace(".", "_")


def emit_loop(
field_name: str,
iter_expr: str,
child_method: str,
) -> str:
# Helper to emit a for-loop over a collection with optional headers guard
if field_name == "headers":
return f"""\
if not self.skip_headers:
for v in {iter_expr}:
await self._visit_{child_method}(fs, v)"""
else:
return f"""\
for v in {iter_expr}:
await self._visit_{child_method}(fs, v)"""


def emit_singular(
field_name: str, access_expr: str, child_method: str, check_presence: bool
) -> str:
# Helper to emit a singular field visit with presence check and optional headers guard
if check_presence:
if field_name == "headers":
return f"""\
if not self.skip_headers:
if o.HasField("{field_name}"):
await self._visit_{child_method}(fs, {access_expr})"""
else:
return f"""\
if o.HasField("{field_name}"):
await self._visit_{child_method}(fs, {access_expr})"""
else:
if field_name == "headers":
return f"""\
if not self.skip_headers:
await self._visit_{child_method}(fs, {access_expr})"""
else:
return f"""\
await self._visit_{child_method}(fs, {access_expr})"""


class VisitorGenerator:
def __init__(self):
# Track which message descriptors have visitor methods generated
self.generated: dict[str, bool] = {
Payload.DESCRIPTOR.full_name: True,
Payloads.DESCRIPTOR.full_name: True,
}
self.in_progress: set[str] = set()
self.methods: list[str] = [
""" async def _visit_temporal_api_common_v1_Payload(self, fs, o):
await fs.visit_payload(o)
""",
""" async def _visit_temporal_api_common_v1_Payloads(self, fs, o):
await fs.visit_payloads(o.payloads)
""",
""" async def _visit_payload_container(self, fs, o):
await fs.visit_payloads(o)
""",
]

def check_repeated(self, child_desc, field, iter_expr) -> Optional[str]:
# Special case for repeated payloads, handle them directly
if child_desc.full_name == Payload.DESCRIPTOR.full_name:
return emit_singular(field.name, iter_expr, "payload_container", False)
else:
child_needed = self.walk(child_desc)
if child_needed:
return emit_loop(
field.name,
iter_expr,
name_for(child_desc),
)
else:
return None

def walk(self, desc: Descriptor) -> bool:
key = desc.full_name
if key in self.generated:
return self.generated[key]
if key in self.in_progress:
# Break cycles; if another path proves this node needed, we'll revisit
return False

needed = False
self.in_progress.add(key)
lines: list[str] = [f" async def _visit_{name_for(desc)}(self, fs, o):"]
# If this is the SearchAttributes message, allow skipping
if desc.full_name == SearchAttributes.DESCRIPTOR.full_name:
lines.append(" if self.skip_search_attributes:")
lines.append(" return")

for field in desc.fields:
if field.type != FieldDescriptor.TYPE_MESSAGE:
continue

# Repeated fields (including maps which are represented as repeated messages)
if field.label == FieldDescriptor.LABEL_REPEATED:
if (
field.message_type is not None
and field.message_type.GetOptions().map_entry
):
val_fd = field.message_type.fields_by_name.get("value")
if (
val_fd is not None
and val_fd.type == FieldDescriptor.TYPE_MESSAGE
):
child_desc = val_fd.message_type
child_needed = self.walk(child_desc)
if child_needed:
needed = True
lines.append(
emit_loop(
field.name,
f"o.{field.name}.values()",
name_for(child_desc),
)
)

key_fd = field.message_type.fields_by_name.get("key")
if (
key_fd is not None
and key_fd.type == FieldDescriptor.TYPE_MESSAGE
):
child_desc = key_fd.message_type
child_needed = self.walk(child_desc)
if child_needed:
needed = True
lines.append(
emit_loop(
field.name,
f"o.{field.name}.keys()",
name_for(child_desc),
)
)
else:
child = self.check_repeated(
field.message_type, field, f"o.{field.name}"
)
if child is not None:
needed = True
lines.append(child)
else:
child_desc = field.message_type
child_needed = self.walk(child_desc)
needed |= child_needed
if child_needed:
lines.append(
emit_singular(
field.name, f"o.{field.name}", name_for(child_desc), True
)
)

self.generated[key] = needed
self.in_progress.discard(key)
if needed:
self.methods.append("\n".join(lines) + "\n")
return needed

def generate(self, roots: list[Descriptor]) -> str:
"""
Generate Python source code that, given a function f(Payload) -> Payload,
applies it to every Payload contained within a WorkflowActivation tree.

The generated code defines async visitor functions for each reachable
protobuf message type starting from WorkflowActivation, including support
for repeated fields and map entries, and a convenience entrypoint
function `visit`.
"""

# We avoid importing google.api deps in service protos; expand by walking from
# WorkflowActivationCompletion root which references many command messages.
for r in roots:
self.walk(r)

header = """
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we want something pretty visible here indicating that it's generated code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, fair

import abc
from typing import Any, MutableSequence

from temporalio.api.common.v1.message_pb2 import Payload

class VisitorFunctions(abc.ABC):
\"\"\"Set of functions which can be called by the visitor. Allows handling payloads as a sequence.\"\"\"
Copy link
Contributor

@dandavison dandavison Sep 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python docstring law dictates that the first line shall have one sentence only and any subsequent sentences shall be separated by a newline from the first.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would maybe call it docstring suggestion if the docstring linter doesn't care, which it seemingly doesn't.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah not sure why that is. Regardless, it's the style we follow for public docstrings, so it's less distracting to just follow it everywhere. It's pretty ubiquitous

https://peps.python.org/pep-0008/#documentation-strings
https://peps.python.org/pep-0257/#multi-line-docstrings

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wasn't trying to say I wouldn't

@abc.abstractmethod
async def visit_payload(self, payload: Payload) -> None:
\"\"\"Called when encountering a single payload.\"\"\"
raise NotImplementedError()

@abc.abstractmethod
async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
\"\"\"Called when encountering multiple payloads together.\"\"\"
raise NotImplementedError()

class PayloadVisitor:
\"\"\"A visitor for payloads. Applies a function to every payload in a tree of messages.\"\"\"
def __init__(
self, *, skip_search_attributes: bool = False, skip_headers: bool = False
):
\"\"\"Creates a new payload visitor.\"\"\"
self.skip_search_attributes = skip_search_attributes
self.skip_headers = skip_headers

async def visit(
self, fs: VisitorFunctions, root: Any
) -> None:
\"\"\"Visits the given root message with the given function.\"\"\"
method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_")
method = getattr(self, method_name, None)
if method is not None:
await method(fs, root)
else:
raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}")

"""

return header + "\n".join(self.methods)


def write_generated_visitors_into_visitor_generated_py() -> None:
"""Write the generated visitor code into _visitor.py."""
out_path = base_dir / "temporalio" / "bridge" / "_visitor.py"

# Build root descriptors: WorkflowActivation, WorkflowActivationCompletion,
# and all messages from selected API modules
roots: list[Descriptor] = [
WorkflowActivation.DESCRIPTOR,
WorkflowActivationCompletion.DESCRIPTOR,
]

code = VisitorGenerator().generate(roots)
out_path.write_text(code)


if __name__ == "__main__":
print("Generating temporalio/bridge/_visitor.py...", file=sys.stderr)
write_generated_visitors_into_visitor_generated_py()
subprocess.run(["uv", "run", "ruff", "format", "temporalio/bridge/_visitor.py"])
Loading
Loading