11import collections
22import os .path
33import tempfile
4+ from textwrap import dedent
45from typing import DefaultDict , List , Sequence
56
67import black
@@ -220,20 +221,27 @@ def message_encoder(
220221
221222def generate_river_module (
222223 module_name : str ,
224+ pb_module_name : str ,
223225 fds : descriptor_pb2 .FileDescriptorSet ,
224226) -> Sequence [str ]:
225227 """Generates the lines of a River module."""
226228 chunks : List [str ] = [
227- "# Code generated by river.codegen. DO NOT EDIT." ,
228- "import datetime" ,
229- "from typing import Any, Dict, Mapping, Tuple" ,
230- "" ,
231- "from google.protobuf import timestamp_pb2" ,
232- "from google.protobuf.wrappers_pb2 import BoolValue" ,
229+ dedent (
230+ f"""\
231+ # Code generated by river.codegen. DO NOT EDIT.
232+ import datetime
233+ from typing import Any, Dict, Mapping, Tuple
234+
235+ from google.protobuf import timestamp_pb2
236+ from google.protobuf.wrappers_pb2 import BoolValue
237+
238+ import replit_river as river
239+
240+ from { module_name } import { pb_module_name } _pb2, { pb_module_name } _pb2_grpc
241+ """
242+ ),
233243 "" ,
234- "import replit_river as river" ,
235244 "" ,
236- f"from . import { module_name } _pb2, { module_name } _pb2_grpc\n \n " ,
237245 ]
238246 for pd in fds .file :
239247
@@ -242,15 +250,15 @@ def _remove_namespace(name: str) -> str:
242250
243251 # Generate the message encoders/decoders.
244252 for message in pd .message_type :
245- chunks .extend (message_encoder (module_name , message ))
246- chunks .extend (message_decoder (module_name , message ))
253+ chunks .extend (message_encoder (pb_module_name , message ))
254+ chunks .extend (message_decoder (pb_module_name , message ))
247255
248256 # Generate the service stubs.
249257 for service in pd .service :
250258 chunks .extend (
251259 [
252260 f"""def add_{ service .name } Servicer_to_server(
253- servicer: { module_name } _pb2_grpc.{ service .name } Servicer,
261+ servicer: { pb_module_name } _pb2_grpc.{ service .name } Servicer,
254262 server: river.Server,
255263 ) -> None:""" ,
256264 (
@@ -301,7 +309,11 @@ def _remove_namespace(name: str) -> str:
301309 return chunks
302310
303311
304- def proto_to_river_server_codegen (proto_path : str , target_directory : str ) -> None :
312+ def proto_to_river_server_codegen (
313+ module_name : str ,
314+ proto_path : str ,
315+ target_directory : str ,
316+ ) -> None :
305317 fds = descriptor_pb2 .FileDescriptorSet ()
306318 with tempfile .TemporaryDirectory () as tempdir :
307319 descriptor_path = os .path .join (tempdir , "descriptor.pb" )
@@ -317,12 +329,12 @@ def proto_to_river_server_codegen(proto_path: str, target_directory: str) -> Non
317329 )
318330 with open (descriptor_path , "rb" ) as f :
319331 fds .ParseFromString (f .read ())
320- module_name = os .path .splitext (os .path .basename (proto_path ))[0 ]
332+ pb_module_name = os .path .splitext (os .path .basename (proto_path ))[0 ]
321333 contents = black .format_str (
322- "\n " .join (generate_river_module (module_name , fds )),
334+ "\n " .join (generate_river_module (module_name , pb_module_name , fds )),
323335 mode = black .FileMode (string_normalization = False ),
324336 )
325337 os .makedirs (target_directory , exist_ok = True )
326- output_path = f"{ target_directory } /{ module_name } _river.py"
338+ output_path = f"{ target_directory } /{ pb_module_name } _river.py"
327339 with open (output_path , "w" ) as f :
328340 f .write (contents )
0 commit comments