44from pathlib import Path
55from textwrap import dedent , indent
66from typing import (
7- Any ,
87 Dict ,
98 List ,
109 Literal ,
@@ -180,15 +179,7 @@ class RiverIntersectionType(BaseModel):
180179 allOf : List ["RiverType" ]
181180
182181
183- class RiverNotType (BaseModel ):
184- """This is used to represent void / never."""
185-
186- not_ : Any = Field (..., alias = "not" )
187-
188-
189- RiverType = Union [
190- RiverConcreteType , RiverUnionType , RiverNotType , RiverIntersectionType
191- ]
182+ RiverType = Union [RiverConcreteType , RiverUnionType , RiverIntersectionType ]
192183
193184
194185class RiverProcedure (BaseModel ):
@@ -239,8 +230,6 @@ def encode_type(
239230) -> Tuple [TypeExpression , list [ModuleName ], list [FileContents ], set [TypeName ]]:
240231 encoder_name : Optional [str ] = None # defining this up here to placate mypy
241232 chunks : List [FileContents ] = []
242- if isinstance (type , RiverNotType ):
243- return (TypeName ("None" ), [], [], set ())
244233 if isinstance (type , RiverUnionType ):
245234 typeddict_encoder = list [str ]()
246235 encoder_names : set [TypeName ] = set ()
@@ -352,7 +341,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
352341 else :
353342 local_discriminator = "FIXME: Ambiguous discriminators"
354343 typeddict_encoder .append (
355- f" if ' { local_discriminator } ' in x else "
344+ f" if { repr ( local_discriminator ) } in x else "
356345 )
357346 typeddict_encoder .pop () # Drop the last ternary
358347 typeddict_encoder .append (")" )
@@ -372,8 +361,8 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
372361 typeddict_encoder .append (f"{ encoder_name } (x)" )
373362 typeddict_encoder .append (
374363 f"""
375- if x[' { discriminator_name } ' ]
376- == ' { discriminator_value } '
364+ if x[{ repr ( discriminator_name ) } ]
365+ == { repr ( discriminator_value ) }
377366 else
378367 """ ,
379368 )
@@ -393,7 +382,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
393382 [
394383 dedent (
395384 f"""\
396- { encoder_name } : Callable[[' { prefix } ' ], Any] = (
385+ { encoder_name } : Callable[[{ repr ( prefix ) } ], Any] = (
397386 lambda x:
398387 """ .rstrip ()
399388 )
@@ -450,14 +439,17 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
450439 chunks .append (
451440 FileContents (
452441 "\n " .join (
453- [f"{ encoder_name } : Callable[['{ prefix } '], Any] = (lambda x: " ]
442+ [
443+ f"{ encoder_name } : Callable[[{ repr (prefix )} ], Any] = ("
444+ "lambda x: "
445+ ]
454446 + typeddict_encoder
455447 + [")" ]
456448 )
457449 )
458450 )
459451 return (prefix , in_module , chunks , encoder_names )
460- if isinstance (type , RiverIntersectionType ):
452+ elif isinstance (type , RiverIntersectionType ):
461453
462454 def extract_props (tpe : RiverType ) -> list [dict [str , RiverType ]]:
463455 if isinstance (tpe , RiverUnionType ):
@@ -478,15 +470,17 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
478470 base_model ,
479471 in_module ,
480472 )
481- if isinstance (type , RiverConcreteType ):
473+ elif isinstance (type , RiverConcreteType ):
482474 typeddict_encoder = list [str ]()
483475 if type .type is None :
484476 # Handle the case where type is not specified
485477 typeddict_encoder .append ("x" )
486478 return (TypeName ("Any" ), [], [], set ())
479+ elif type .type == "not" :
480+ return (TypeName ("None" ), [], [], set ())
487481 elif type .type == "string" :
488482 if type .const :
489- typeddict_encoder .append (f"' { type .const } '" )
483+ typeddict_encoder .append (repr ( type .const ) )
490484 return (LiteralTypeExpr (type .const ), [], [], set ())
491485 else :
492486 typeddict_encoder .append ("x" )
@@ -565,48 +559,48 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
565559 name ,
566560 prop ,
567561 ) in sorted (list (type .properties .items ()), key = lambda xs : xs [0 ]):
568- typeddict_encoder .append (f"' { name } ' :" )
562+ typeddict_encoder .append (f"{ repr ( name ) } :" )
569563 type_name , _ , contents , _ = encode_type (
570564 prop , TypeName (prefix + name .title ()), base_model , in_module
571565 )
572566 encoder_name = None
573567 chunks .extend (contents )
574568 if base_model == "TypedDict" :
575- if isinstance (prop , RiverNotType ):
576- typeddict_encoder .append ("'not implemented'" )
577- elif isinstance (prop , RiverUnionType ):
569+ if isinstance (prop , RiverUnionType ):
578570 encoder_name = TypeName (
579571 f"encode_{ ensure_literal_type (type_name )} "
580572 )
581573 encoder_names .add (encoder_name )
582- typeddict_encoder .append (f"{ encoder_name } (x[' { name } ' ])" )
574+ typeddict_encoder .append (f"{ encoder_name } (x[{ repr ( name ) } ])" )
583575 if name not in type .required :
584- typeddict_encoder .append (f"if x[' { name } ' ] else None" )
576+ typeddict_encoder .append (f"if x[{ repr ( name ) } ] else None" )
585577 elif isinstance (prop , RiverIntersectionType ):
586578 encoder_name = TypeName (
587579 f"encode_{ ensure_literal_type (type_name )} "
588580 )
589581 encoder_names .add (encoder_name )
590- typeddict_encoder .append (f"{ encoder_name } (x[' { name } ' ])" )
582+ typeddict_encoder .append (f"{ encoder_name } (x[{ repr ( name ) } ])" )
591583 elif isinstance (prop , RiverConcreteType ):
592584 if name == "$kind" :
593585 safe_name = "kind"
594586 else :
595587 safe_name = name
596- if prop .type == "object" and not prop .patternProperties :
588+ if prop .type == "not" :
589+ typeddict_encoder .append ("'not implemented'" )
590+ elif prop .type == "object" and not prop .patternProperties :
597591 encoder_name = TypeName (
598592 f"encode_{ ensure_literal_type (type_name )} "
599593 )
600594 encoder_names .add (encoder_name )
601595 typeddict_encoder .append (
602- f"{ encoder_name } (x[' { safe_name } ' ])"
596+ f"{ encoder_name } (x[{ repr ( safe_name ) } ])"
603597 )
604598 if name not in prop .required :
605599 typeddict_encoder .append (
606600 dedent (
607601 f"""
608- if ' { safe_name } ' in x
609- and x[' { safe_name } ' ] is not None
602+ if { repr ( safe_name ) } in x
603+ and x[{ repr ( safe_name ) } ] is not None
610604 else None
611605 """
612606 )
@@ -615,7 +609,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
615609 items = cast (RiverConcreteType , prop ).items
616610 assert items , "Somehow items was none"
617611 if is_literal (cast (RiverType , items )):
618- typeddict_encoder .append (f"x[' { name } ' ]" )
612+ typeddict_encoder .append (f"x[{ repr ( name ) } ]" )
619613 else :
620614 match type_name :
621615 case ListTypeExpr (inner_type_name ):
@@ -628,16 +622,16 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
628622 f"""\
629623 [
630624 { encoder_name } (y)
631- for y in x[' { name } ' ]
625+ for y in x[{ repr ( name ) } ]
632626 ]
633627 """ .rstrip ()
634628 )
635629 )
636630 else :
637631 if name in prop .required :
638- typeddict_encoder .append (f"x[' { safe_name } ' ]" )
632+ typeddict_encoder .append (f"x[{ repr ( safe_name ) } ]" )
639633 else :
640- typeddict_encoder .append (f"x.get(' { safe_name } ' )" )
634+ typeddict_encoder .append (f"x.get({ repr ( safe_name ) } )" )
641635
642636 if name == "$kind" :
643637 # If the field is a literal, the Python type-checker will complain
@@ -657,7 +651,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
657651 f"""\
658652 = Field(
659653 default=None,
660- alias=' { name } ' , # type: ignore
654+ alias={ repr ( name ) } , # type: ignore
661655 )
662656 """
663657 )
@@ -671,7 +665,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
671665 f"""\
672666 = Field(
673667 { field_value } ,
674- alias=' { name } ' , # type: ignore
668+ alias={ repr ( name ) } , # type: ignore
675669 )
676670 """
677671 )
@@ -714,7 +708,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
714708 [
715709 dedent (
716710 f"""\
717- { encoder_name } : Callable[[' { prefix } ' ], Any] = (
711+ { encoder_name } : Callable[[{ repr ( prefix ) } ], Any] = (
718712 lambda { binding } :
719713 """
720714 )
@@ -781,7 +775,7 @@ def __init__(self, client: river.Client[Any]):
781775 module_names = [ModuleName (name )]
782776 init_type : Optional [TypeExpression ] = None
783777 if procedure .init :
784- init_type , module_info , input_chunks , encoder_names = encode_type (
778+ init_type , module_info , init_chunks , encoder_names = encode_type (
785779 procedure .init ,
786780 TypeName (f"{ name .title ()} Init" ),
787781 input_base_class ,
@@ -791,7 +785,7 @@ def __init__(self, client: river.Client[Any]):
791785 (
792786 [extract_inner_type (init_type ), * encoder_names ],
793787 module_info ,
794- input_chunks ,
788+ init_chunks ,
795789 )
796790 )
797791 input_type , module_info , input_chunks , encoder_names = encode_type (
@@ -856,31 +850,28 @@ def __init__(self, client: river.Client[Any]):
856850
857851 # Init renderer
858852 render_init_method : Optional [str ] = None
859- if input_base_class == "TypedDict" and init_type :
860- if is_literal (procedure .input ):
861- render_init_method = "lambda x: x"
862- elif isinstance (
863- procedure .input , RiverConcreteType
864- ) and procedure .input .type in ["array" ]:
865- match init_type :
866- case ListTypeExpr (init_type_name ):
867- render_init_method = (
868- f"lambda xs: [encode_{ init_type_name } (x) for x in xs]"
869- )
853+ if init_type and procedure .init is not None :
854+ if input_base_class == "TypedDict" :
855+ if is_literal (procedure .init ):
856+ render_init_method = "lambda x: x"
857+ elif isinstance (
858+ procedure .init , RiverConcreteType
859+ ) and procedure .init .type in ["array" ]:
860+ match init_type :
861+ case ListTypeExpr (init_type_name ):
862+ render_init_method = (
863+ f"lambda xs: [encode_{ init_type_name } (x) for x in xs]"
864+ )
865+ else :
866+ render_init_method = f"encode_{ ensure_literal_type (init_type )} "
870867 else :
871- render_init_method = f"encode_{ ensure_literal_type (init_type )} "
872- else :
873- render_init_method = f"""\
874- lambda x: TypeAdapter({ render_type_expr (input_type )} )
875- .validate_python
876- """
877- if isinstance (
878- procedure .init , RiverConcreteType
879- ) and procedure .init .type not in ["object" , "array" ]:
880- render_init_method = "lambda x: x"
868+ render_init_method = f"""\
869+ lambda x: TypeAdapter({ render_type_expr (init_type )} )
870+ .validate_python
871+ """
881872
882873 assert (
883- render_init_method
874+ init_type is None or render_init_method
884875 ), f"Unable to derive the init encoder from: { input_type } "
885876
886877 # Input renderer
@@ -922,10 +913,6 @@ def __init__(self, client: river.Client[Any]):
922913 parse_output_method = "lambda x: None"
923914
924915 if procedure .type == "rpc" :
925- control_flow_keyword = "return "
926- if output_type == "None" :
927- control_flow_keyword = ""
928-
929916 current_chunks .extend (
930917 [
931918 reindent (
@@ -935,9 +922,9 @@ async def {name}(
935922 self,
936923 input: { render_type_expr (input_type )} ,
937924 ) -> { render_type_expr (output_type )} :
938- { control_flow_keyword } await self.client.send_rpc(
939- ' { schema_name } ' ,
940- ' { name } ' ,
925+ return await self.client.send_rpc(
926+ { repr ( schema_name ) } ,
927+ { repr ( name ) } ,
941928 input,
942929 { reindent (" " , render_input_method )} ,
943930 { reindent (" " , parse_output_method )} ,
@@ -958,8 +945,8 @@ async def {name}(
958945 input: { render_type_expr (input_type )} ,
959946 ) -> AsyncIterator[{ render_type_expr (output_or_error_type )} ]:
960947 return await self.client.send_subscription(
961- ' { schema_name } ' ,
962- ' { name } ' ,
948+ { repr ( schema_name ) } ,
949+ { repr ( name ) } ,
963950 input,
964951 { reindent (" " , render_input_method )} ,
965952 { reindent (" " , parse_output_method )} ,
@@ -970,10 +957,8 @@ async def {name}(
970957 ]
971958 )
972959 elif procedure .type == "upload" :
973- control_flow_keyword = "return "
974- if output_type == "None" :
975- control_flow_keyword = ""
976960 if init_type :
961+ assert render_init_method , "Expected an init renderer!"
977962 current_chunks .extend (
978963 [
979964 reindent (
@@ -984,9 +969,9 @@ async def {name}(
984969 init: { init_type } ,
985970 inputStream: AsyncIterable[{ render_type_expr (input_type )} ],
986971 ) -> { output_type } :
987- { control_flow_keyword } await self.client.send_upload(
988- ' { schema_name } ' ,
989- ' { name } ' ,
972+ return await self.client.send_upload(
973+ { repr ( schema_name ) } ,
974+ { repr ( name ) } ,
990975 init,
991976 inputStream,
992977 { reindent (" " , render_init_method )} ,
@@ -1008,9 +993,9 @@ async def {name}(
1008993 self,
1009994 inputStream: AsyncIterable[{ render_type_expr (input_type )} ],
1010995 ) -> { render_type_expr (output_or_error_type )} :
1011- { control_flow_keyword } await self.client.send_upload(
1012- ' { schema_name } ' ,
1013- ' { name } ' ,
996+ return await self.client.send_upload(
997+ { repr ( schema_name ) } ,
998+ { repr ( name ) } ,
1014999 None,
10151000 inputStream,
10161001 None,
@@ -1024,6 +1009,7 @@ async def {name}(
10241009 )
10251010 elif procedure .type == "stream" :
10261011 if init_type :
1012+ assert render_init_method , "Expected an init renderer!"
10271013 current_chunks .extend (
10281014 [
10291015 reindent (
@@ -1035,8 +1021,8 @@ async def {name}(
10351021 inputStream: AsyncIterable[{ render_type_expr (input_type )} ],
10361022 ) -> AsyncIterator[{ render_type_expr (output_or_error_type )} ]:
10371023 return await self.client.send_stream(
1038- ' { schema_name } ' ,
1039- ' { name } ' ,
1024+ { repr ( schema_name ) } ,
1025+ { repr ( name ) } ,
10401026 init,
10411027 inputStream,
10421028 { reindent (" " , render_init_method )} ,
@@ -1059,8 +1045,8 @@ async def {name}(
10591045 inputStream: AsyncIterable[{ render_type_expr (input_type )} ],
10601046 ) -> AsyncIterator[{ render_type_expr (output_or_error_type )} ]:
10611047 return await self.client.send_stream(
1062- ' { schema_name } ' ,
1063- ' { name } ' ,
1048+ { repr ( schema_name ) } ,
1049+ { repr ( name ) } ,
10641050 None,
10651051 inputStream,
10661052 None,
0 commit comments