@@ -39,18 +39,18 @@ def emit_loop(
3939
4040
4141def emit_singular (
42- field_name : str , access_expr : str , child_method : str , check_presence : bool
42+ field_name : str , access_expr : str , child_method : str , presence_word : Optional [ str ]
4343) -> str :
4444 # Helper to emit a singular field visit with presence check and optional headers guard
45- if check_presence :
45+ if presence_word :
4646 if field_name == "headers" :
4747 return f"""\
4848 if not self.skip_headers:
49- if o.HasField("{ field_name } "):
49+ { presence_word } o.HasField("{ field_name } "):
5050 await self._visit_{ child_method } (fs, { access_expr } )"""
5151 else :
5252 return f"""\
53- if o.HasField("{ field_name } "):
53+ { presence_word } o.HasField("{ field_name } "):
5454 await self._visit_{ child_method } (fs, { access_expr } )"""
5555 else :
5656 if field_name == "headers" :
@@ -63,6 +63,67 @@ def emit_singular(
6363
6464
6565class VisitorGenerator :
66+ def generate (self , roots : list [Descriptor ]) -> str :
67+ """
68+ Generate Python source code that, given a function f(Payload) -> Payload,
69+ applies it to every Payload contained within a WorkflowActivation tree.
70+
71+ The generated code defines async visitor functions for each reachable
72+ protobuf message type starting from WorkflowActivation, including support
73+ for repeated fields and map entries, and a convenience entrypoint
74+ function `visit`.
75+ """
76+
77+ for r in roots :
78+ self .walk (r )
79+
80+ header = """
81+ # This file is generated by gen_payload_visitor.py. Changes should be made there.
82+ import abc
83+ from typing import Any, MutableSequence
84+
85+ from temporalio.api.common.v1.message_pb2 import Payload
86+
87+ class VisitorFunctions(abc.ABC):
88+ \" \" \" Set of functions which can be called by the visitor.
89+ Allows handling payloads as a sequence.
90+ \" \" \"
91+ @abc.abstractmethod
92+ async def visit_payload(self, payload: Payload) -> None:
93+ \" \" \" Called when encountering a single payload.\" \" \"
94+ raise NotImplementedError()
95+
96+ @abc.abstractmethod
97+ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
98+ \" \" \" Called when encountering multiple payloads together.\" \" \"
99+ raise NotImplementedError()
100+
101+ class PayloadVisitor:
102+ \" \" \" A visitor for payloads.
103+ Applies a function to every payload in a tree of messages.
104+ \" \" \"
105+ def __init__(
106+ self, *, skip_search_attributes: bool = False, skip_headers: bool = False
107+ ):
108+ \" \" \" Creates a new payload visitor.\" \" \"
109+ self.skip_search_attributes = skip_search_attributes
110+ self.skip_headers = skip_headers
111+
112+ async def visit(
113+ self, fs: VisitorFunctions, root: Any
114+ ) -> None:
115+ \" \" \" Visits the given root message with the given function.\" \" \"
116+ method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_")
117+ method = getattr(self, method_name, None)
118+ if method is not None:
119+ await method(fs, root)
120+ else:
121+ raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}")
122+
123+ """
124+
125+ return header + "\n " .join (self .methods )
126+
66127 def __init__ (self ):
67128 # Track which message descriptors have visitor methods generated
68129 self .generated : dict [str , bool ] = {
@@ -71,21 +132,24 @@ def __init__(self):
71132 }
72133 self .in_progress : set [str ] = set ()
73134 self .methods : list [str ] = [
74- """ async def _visit_temporal_api_common_v1_Payload(self, fs, o):
75- await fs.visit_payload(o)
135+ """\
136+ async def _visit_temporal_api_common_v1_Payload(self, fs, o):
137+ await fs.visit_payload(o)
76138 """ ,
77- """ async def _visit_temporal_api_common_v1_Payloads(self, fs, o):
78- await fs.visit_payloads(o.payloads)
139+ """\
140+ async def _visit_temporal_api_common_v1_Payloads(self, fs, o):
141+ await fs.visit_payloads(o.payloads)
79142 """ ,
80- """ async def _visit_payload_container(self, fs, o):
81- await fs.visit_payloads(o)
143+ """\
144+ async def _visit_payload_container(self, fs, o):
145+ await fs.visit_payloads(o)
82146 """ ,
83147 ]
84148
85149 def check_repeated (self , child_desc , field , iter_expr ) -> Optional [str ]:
86150 # Special case for repeated payloads, handle them directly
87151 if child_desc .full_name == Payload .DESCRIPTOR .full_name :
88- return emit_singular (field .name , iter_expr , "payload_container" , False )
152+ return emit_singular (field .name , iter_expr , "payload_container" , None )
89153 else :
90154 child_needed = self .walk (child_desc )
91155 if child_needed :
@@ -105,7 +169,7 @@ def walk(self, desc: Descriptor) -> bool:
105169 # Break cycles; if another path proves this node needed, we'll revisit
106170 return False
107171
108- needed = False
172+ has_payload = False
109173 self .in_progress .add (key )
110174 lines : list [str ] = [f" async def _visit_{ name_for (desc )} (self, fs, o):" ]
111175 # If this is the SearchAttributes message, allow skipping
@@ -146,7 +210,7 @@ def walk(self, desc: Descriptor) -> bool:
146210 child_desc = val_fd .message_type
147211 child_needed = self .walk (child_desc )
148212 if child_needed :
149- needed = True
213+ has_payload = True
150214 lines .append (
151215 emit_loop (
152216 field .name ,
@@ -163,7 +227,7 @@ def walk(self, desc: Descriptor) -> bool:
163227 child_desc = key_fd .message_type
164228 child_needed = self .walk (child_desc )
165229 if child_needed :
166- needed = True
230+ has_payload = True
167231 lines .append (
168232 emit_loop (
169233 field .name ,
@@ -176,16 +240,16 @@ def walk(self, desc: Descriptor) -> bool:
176240 field .message_type , field , f"o.{ field .name } "
177241 )
178242 if child is not None :
179- needed = True
243+ has_payload = True
180244 lines .append (child )
181245 else :
182246 child_desc = field .message_type
183- child_needed = self .walk (child_desc )
184- needed |= child_needed
185- if child_needed :
247+ child_has_payload = self .walk (child_desc )
248+ has_payload |= child_has_payload
249+ if child_has_payload :
186250 lines .append (
187251 emit_singular (
188- field .name , f"o.{ field .name } " , name_for (child_desc ), True
252+ field .name , f"o.{ field .name } " , name_for (child_desc ), "if"
189253 )
190254 )
191255
@@ -195,81 +259,23 @@ def walk(self, desc: Descriptor) -> bool:
195259 first = True
196260 for field in fields :
197261 child_desc = field .message_type
198- child_needed = self .walk (child_desc )
199- needed |= child_needed
200- if child_needed :
262+ child_has_payload = self .walk (child_desc )
263+ has_payload |= child_has_payload
264+ if child_has_payload :
201265 if_word = "if" if first else "elif"
202266 first = False
203267 line = emit_singular (
204- field .name , f"o.{ field .name } " , name_for (child_desc ), True
205- ). replace ( " if" , f" { if_word } " , 1 )
268+ field .name , f"o.{ field .name } " , name_for (child_desc ), if_word
269+ )
206270 oneof_lines .append (line )
207271 if oneof_lines :
208272 lines .extend (oneof_lines )
209273
210- self .generated [key ] = needed
274+ self .generated [key ] = has_payload
211275 self .in_progress .discard (key )
212- if needed :
276+ if has_payload :
213277 self .methods .append ("\n " .join (lines ) + "\n " )
214- return needed
215-
216- def generate (self , roots : list [Descriptor ]) -> str :
217- """
218- Generate Python source code that, given a function f(Payload) -> Payload,
219- applies it to every Payload contained within a WorkflowActivation tree.
220-
221- The generated code defines async visitor functions for each reachable
222- protobuf message type starting from WorkflowActivation, including support
223- for repeated fields and map entries, and a convenience entrypoint
224- function `visit`.
225- """
226-
227- # We avoid importing google.api deps in service protos; expand by walking from
228- # WorkflowActivationCompletion root which references many command messages.
229- for r in roots :
230- self .walk (r )
231-
232- header = """
233- import abc
234- from typing import Any, MutableSequence
235-
236- from temporalio.api.common.v1.message_pb2 import Payload
237-
238- class VisitorFunctions(abc.ABC):
239- \" \" \" Set of functions which can be called by the visitor. Allows handling payloads as a sequence.\" \" \"
240- @abc.abstractmethod
241- async def visit_payload(self, payload: Payload) -> None:
242- \" \" \" Called when encountering a single payload.\" \" \"
243- raise NotImplementedError()
244-
245- @abc.abstractmethod
246- async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
247- \" \" \" Called when encountering multiple payloads together.\" \" \"
248- raise NotImplementedError()
249-
250- class PayloadVisitor:
251- \" \" \" A visitor for payloads. Applies a function to every payload in a tree of messages.\" \" \"
252- def __init__(
253- self, *, skip_search_attributes: bool = False, skip_headers: bool = False
254- ):
255- \" \" \" Creates a new payload visitor.\" \" \"
256- self.skip_search_attributes = skip_search_attributes
257- self.skip_headers = skip_headers
258-
259- async def visit(
260- self, fs: VisitorFunctions, root: Any
261- ) -> None:
262- \" \" \" Visits the given root message with the given function.\" \" \"
263- method_name = "_visit_" + root.DESCRIPTOR.full_name.replace(".", "_")
264- method = getattr(self, method_name, None)
265- if method is not None:
266- await method(fs, root)
267- else:
268- raise ValueError(f"Unknown root message type: {root.DESCRIPTOR.full_name}")
269-
270- """
271-
272- return header + "\n " .join (self .methods )
278+ return has_payload
273279
274280
275281def write_generated_visitors_into_visitor_generated_py () -> None :
0 commit comments