@@ -839,8 +839,9 @@ def __init__(self, client: river.Client[Any]):
839839 for name , procedure in schema .procedures .items ():
840840 module_names = [ModuleName (name )]
841841 init_type : TypeExpression | None = None
842+ init_module_info : list [ModuleName ] = []
842843 if procedure .init :
843- init_type , module_info , init_chunks , encoder_names = encode_type (
844+ init_type , init_module_info , init_chunks , encoder_names = encode_type (
844845 procedure .init ,
845846 TypeName (f"{ name .title ()} Init" ),
846847 input_base_class ,
@@ -850,34 +851,29 @@ def __init__(self, client: river.Client[Any]):
850851 serdes .append (
851852 (
852853 [extract_inner_type (init_type ), * encoder_names ],
853- module_info ,
854+ init_module_info ,
854855 init_chunks ,
855856 )
856857 )
857- input_type , module_info , input_chunks , encoder_names = encode_type (
858- procedure .input ,
859- TypeName (f"{ name .title ()} Input" ),
860- input_base_class ,
861- module_names ,
862- permit_unknown_members = False ,
863- )
864- input_type_name = extract_inner_type (input_type )
865- input_type_type_adapter_name = TypeName (
866- f"{ render_literal_type (input_type_name )} TypeAdapter"
867- )
868- serdes .append (
869- (
870- [extract_inner_type (input_type ), * encoder_names ],
871- module_info ,
872- input_chunks ,
858+ input_type : TypeExpression | None = None
859+ input_module_info : list [ModuleName ] = []
860+ if procedure .input :
861+ input_type , input_module_info , input_chunks , encoder_names = encode_type (
862+ procedure .input ,
863+ TypeName (f"{ name .title ()} Input" ),
864+ input_base_class ,
865+ module_names ,
866+ permit_unknown_members = False ,
873867 )
874- )
875- serdes .append (
876- _type_adapter_definition (
877- input_type_type_adapter_name , input_type , module_info
868+ serdes .append (
869+ (
870+ [extract_inner_type (input_type ), * encoder_names ],
871+ input_module_info ,
872+ input_chunks ,
873+ )
878874 )
879- )
880- output_type , module_info , output_chunks , encoder_names = encode_type (
875+
876+ output_type , output_module_info , output_chunks , encoder_names = encode_type (
881877 procedure .output ,
882878 TypeName (f"{ name .title ()} Output" ),
883879 "BaseModel" ,
@@ -888,7 +884,7 @@ def __init__(self, client: river.Client[Any]):
888884 serdes .append (
889885 (
890886 [output_type_name , * encoder_names ],
891- module_info ,
887+ output_module_info ,
892888 output_chunks ,
893889 )
894890 )
@@ -897,12 +893,12 @@ def __init__(self, client: river.Client[Any]):
897893 )
898894 serdes .append (
899895 _type_adapter_definition (
900- output_type_type_adapter_name , output_type , module_info
896+ output_type_type_adapter_name , output_type , output_module_info
901897 )
902898 )
903- output_module_info = module_info
899+
904900 if procedure .errors :
905- error_type , module_info , errors_chunks , encoder_names = encode_type (
901+ error_type , error_module_info , errors_chunks , encoder_names = encode_type (
906902 procedure .errors ,
907903 TypeName (f"{ name .title ()} Errors" ),
908904 "RiverError" ,
@@ -914,7 +910,7 @@ def __init__(self, client: river.Client[Any]):
914910 error_type = error_type_name
915911 else :
916912 error_type_name = extract_inner_type (error_type )
917- serdes .append (([error_type_name ], module_info , errors_chunks ))
913+ serdes .append (([error_type_name ], error_module_info , errors_chunks ))
918914
919915 else :
920916 error_type_name = TypeName ("RiverError" )
@@ -924,11 +920,9 @@ def __init__(self, client: river.Client[Any]):
924920 f"{ render_literal_type (error_type_name )} TypeAdapter"
925921 )
926922 if error_type_type_adapter_name .value != "RiverErrorTypeAdapter" :
927- if len (module_info ) == 0 :
928- module_info = output_module_info
929923 serdes .append (
930924 _type_adapter_definition (
931- error_type_type_adapter_name , error_type , module_info
925+ error_type_type_adapter_name , error_type , output_module_info
932926 )
933927 )
934928 output_or_error_type = UnionTypeExpr ([output_type , error_type_name ])
@@ -975,48 +969,50 @@ def __init__(self, client: river.Client[Any]):
975969 )
976970 serdes .append (
977971 _type_adapter_definition (
978- init_type_type_adapter_name , init_type , module_info
972+ init_type_type_adapter_name , init_type , init_module_info
979973 )
980974 )
981975 render_init_method = f"""\
982976 lambda x: { render_type_expr (init_type_type_adapter_name )}
983977 .validate_python
984978 """
985979
986- assert init_type is None or render_init_method , (
987- f"Unable to derive the init encoder from: { init_type } "
988- )
989-
990980 # Input renderer
991981 render_input_method : str | None = None
992- if input_base_class == "TypedDict" :
993- if is_literal (procedure .input ):
994- render_input_method = "lambda x: x"
995- elif isinstance (
996- procedure .input , RiverConcreteType
997- ) and procedure .input .type in ["array" ]:
998- match input_type :
999- case ListTypeExpr (list_type ):
1000- render_input_method = f"""\
1001- lambda xs: [
1002- encode_{ render_literal_type (list_type )} (x) for x in xs
1003- ]
1004- """
1005- else :
1006- render_input_method = f"encode_{ render_literal_type (input_type )} "
1007- else :
1008- render_input_method = f"""\
1009- lambda x: { render_type_expr (input_type_type_adapter_name )}
1010- .dump_python(
1011- x, # type: ignore[arg-type]
1012- by_alias=True,
1013- exclude_none=True,
1014- )
982+ if input_type and procedure .input is not None :
983+ if input_base_class == "TypedDict" :
984+ if is_literal (procedure .input ):
985+ render_input_method = "lambda x: x"
986+ elif isinstance (
987+ procedure .input , RiverConcreteType
988+ ) and procedure .input .type in ["array" ]:
989+ match input_type :
990+ case ListTypeExpr (list_type ):
991+ render_input_method = f"""\
992+ lambda xs: [
993+ encode_{ render_literal_type (list_type )} (x) for x in xs
994+ ]
1015995 """
1016-
1017- assert render_input_method , (
1018- f"Unable to derive the input encoder from: { input_type } "
1019- )
996+ else :
997+ render_input_method = f"encode_{ render_literal_type (input_type )} "
998+ else :
999+ input_type_name = extract_inner_type (input_type )
1000+ input_type_type_adapter_name = TypeName (
1001+ f"{ render_literal_type (input_type_name )} TypeAdapter"
1002+ )
1003+ serdes .append (
1004+ _type_adapter_definition (
1005+ input_type_type_adapter_name , input_type , input_module_info
1006+ )
1007+ )
1008+ render_input_method = f"""\
1009+ lambda x: { render_type_expr (input_type_type_adapter_name )}
1010+ .dump_python(
1011+ x, # type: ignore[arg-type]
1012+ by_alias=True,
1013+ exclude_none=True,
1014+ )
1015+ """
10201016
10211017 if isinstance (output_type , NoneTypeExpr ):
10221018 parse_output_method = "lambda x: None"
@@ -1075,7 +1071,6 @@ async def {name}(
10751071 )
10761072 elif procedure .type == "upload" :
10771073 if init_type :
1078- assert render_init_method , "Expected an init renderer!"
10791074 current_chunks .extend (
10801075 [
10811076 reindent (
@@ -1132,7 +1127,6 @@ async def {name}(
11321127 ]
11331128 )
11341129 if init_type :
1135- assert render_init_method , "Expected an init renderer!"
11361130 current_chunks .extend (
11371131 [
11381132 reindent (
0 commit comments