1+ import subprocess
12import sys
23from pathlib import Path
34
@@ -30,38 +31,33 @@ def name_for(desc: Descriptor) -> str:
3031 return desc .full_name .replace ("." , "_" )
3132
3233 def emit_loop (
33- lines : list [str ],
3434 field_name : str ,
3535 iter_expr : str ,
3636 var_name : str ,
3737 child_method : str ,
38- ) -> None :
38+ ) -> str :
3939 # Helper to emit a for-loop over a collection with optional headers guard
4040 if field_name == "headers" :
41- lines .append (" if not self.skip_headers:" )
42- lines .append (f" for { var_name } in { iter_expr } :" )
43- lines .append (
44- f" await self.visit_{ child_method } (f, { var_name } )"
45- )
41+ return f"""\
42+ if not self.skip_headers:
43+ for { var_name } in { iter_expr } :
44+ await self._visit_{ child_method } (f, { var_name } )"""
4645 else :
47- lines .append (f" for { var_name } in { iter_expr } :" )
48- lines .append (f" await self.visit_{ child_method } (f, { var_name } )" )
46+ return f"""\
47+ for { var_name } in { iter_expr } :
48+ await self._visit_{ child_method } (f, { var_name } )"""
4949
50- def emit_singular (
51- lines : list [str ], field_name : str , access_expr : str , child_method : str
52- ) -> None :
50+ def emit_singular (field_name : str , access_expr : str , child_method : str ) -> str :
5351 # Helper to emit a singular field visit with presence check and optional headers guard
5452 if field_name == "headers" :
55- lines .append (" if not self.skip_headers:" )
56- lines .append (f" if o.HasField('{ field_name } '):" )
57- lines .append (
58- f" await self.visit_{ child_method } (f, { access_expr } )"
59- )
53+ return f"""\
54+ if not self.skip_headers:
55+ if o.HasField("{ field_name } "):
56+ await self._visit_{ child_method } (f, { access_expr } )"""
6057 else :
61- lines .append (f" if o.HasField('{ field_name } '):" )
62- lines .append (
63- f" await self.visit_{ child_method } (f, { access_expr } )"
64- )
58+ return f"""\
59+ if o.HasField("{ field_name } "):
60+ await self._visit_{ child_method } (f, { access_expr } )"""
6561
6662 # Track which message descriptors have visitor methods generated
6763 generated : dict [str , bool ] = {}
@@ -79,15 +75,15 @@ def walk(desc: Descriptor) -> bool:
7975 if desc .full_name == Payload .DESCRIPTOR .full_name :
8076 generated [key ] = True
8177 methods .append (
82- """ async def visit_temporal_api_common_v1_Payload (self, f, o):
78+ """ async def _visit_temporal_api_common_v1_Payload (self, f, o):
8379 o.CopyFrom(await f(o))
8480"""
8581 )
8682 return True
8783
8884 needed = False
8985 in_progress .add (key )
90- lines : list [str ] = [f" async def visit_ { name_for (desc )} (self, f, o):" ]
86+ lines : list [str ] = [f" async def _visit_ { name_for (desc )} (self, f, o):" ]
9187 # If this is the SearchAttributes message, allow skipping
9288 if desc .full_name == SearchAttributes .DESCRIPTOR .full_name :
9389 lines .append (" if self.skip_search_attributes:" )
@@ -115,12 +111,13 @@ def walk(desc: Descriptor) -> bool:
115111 child_needed = walk (child_desc )
116112 needed |= child_needed
117113 if child_needed :
118- emit_loop (
119- lines ,
120- field .name ,
121- f"o.{ field .name } .values()" ,
122- "v" ,
123- name_for (child_desc ),
114+ lines .append (
115+ emit_loop (
116+ field .name ,
117+ f"o.{ field .name } .values()" ,
118+ "v" ,
119+ name_for (child_desc ),
120+ )
124121 )
125122
126123 if (
@@ -131,32 +128,36 @@ def walk(desc: Descriptor) -> bool:
131128 child_needed = walk (key_desc )
132129 needed |= child_needed
133130 if child_needed :
134- emit_loop (
135- lines ,
136- field .name ,
137- f"o.{ field .name } .keys()" ,
138- "k" ,
139- name_for (key_desc ),
131+ lines .append (
132+ emit_loop (
133+ field .name ,
134+ f"o.{ field .name } .keys()" ,
135+ "k" ,
136+ name_for (key_desc ),
137+ )
140138 )
141139 else :
142140 child_desc = field .message_type
143141 child_needed = walk (child_desc )
144142 needed |= child_needed
145143 if child_needed :
146- emit_loop (
147- lines ,
148- field .name ,
149- f"o.{ field .name } " ,
150- "v" ,
151- name_for (child_desc ),
144+ lines .append (
145+ emit_loop (
146+ field .name ,
147+ f"o.{ field .name } " ,
148+ "v" ,
149+ name_for (child_desc ),
150+ )
152151 )
153152 else :
154153 child_desc = field .message_type
155154 child_needed = walk (child_desc )
156155 needed |= child_needed
157156 if child_needed :
158- emit_singular (
159- lines , field .name , f"o.{ field .name } " , name_for (child_desc )
157+ lines .append (
158+ emit_singular (
159+ field .name , f"o.{ field .name } " , name_for (child_desc )
160+ )
160161 )
161162
162163 generated [key ] = needed
@@ -178,30 +179,43 @@ def walk(desc: Descriptor) -> bool:
178179 for r in roots :
179180 walk (r )
180181
181- header = (
182- "from typing import Any, Awaitable, Callable\n \n "
183- "from temporalio.api.common.v1.message_pb2 import Payload\n \n \n "
184- "class PayloadVisitor:\n "
185- " def __init__(self, *, skip_search_attributes: bool = False, skip_headers: bool = False):\n "
186- " self.skip_search_attributes = skip_search_attributes\n "
187- " self.skip_headers = skip_headers\n \n "
188- " async def visit(self, f: Callable[[Payload], Awaitable[Payload]], root: Any) -> None:\n "
189- " method_name = 'visit_' + root.DESCRIPTOR.full_name.replace('.', '_')\n "
190- " method = getattr(self, method_name, None)\n "
191- " if method is not None:\n "
192- " await method(f, root)\n \n "
193- )
182+ header = """
183+ from typing import Any, Awaitable, Callable
184+
185+ from temporalio.api.common.v1.message_pb2 import Payload
186+
187+
188+ class PayloadVisitor:
189+ \" \" \" A visitor for payloads. Applies a function to every payload in a tree of messages.\" \" \"
190+ def __init__(
191+ self, *, skip_search_attributes: bool = False, skip_headers: bool = False
192+ ):
193+ \" \" \" Creates a new payload visitor.\" \" \"
194+ self.skip_search_attributes = skip_search_attributes
195+ self.skip_headers = skip_headers
196+
197+ async def visit(
198+ self, f: Callable[[Payload], Awaitable[Payload]], root: Any
199+ ) -> None:
200+ \" \" \" Visits the given root message with the given function.\" \" \"
201+ method_name = "visit_" + root.DESCRIPTOR.full_name.replace(".", "_")
202+ method = getattr(self, method_name, None)
203+ if method is not None:
204+ await method(f, root)
205+
206+ """
194207
195208 return header + "\n " .join (methods )
196209
197210
198211def write_generated_visitors_into_visitor_generated_py () -> None :
199- """Write the generated visitor code into visitor_generated .py."""
200- out_path = base_dir / "temporalio" / "bridge" / "visitor_generated .py"
212+ """Write the generated visitor code into _visitor .py."""
213+ out_path = base_dir / "temporalio" / "bridge" / "_visitor .py"
201214 code = gen_workflow_activation_payload_visitor_code ()
202215 out_path .write_text (code )
203216
204217
205218if __name__ == "__main__" :
206- print ("Generating temporalio/bridge/visitor_generated .py..." , file = sys .stderr )
219+ print ("Generating temporalio/bridge/_visitor .py..." , file = sys .stderr )
207220 write_generated_visitors_into_visitor_generated_py ()
221+ subprocess .run (["uv" , "run" , "ruff" , "format" , "temporalio/bridge/_visitor.py" ])
0 commit comments