@@ -109,6 +109,32 @@ def ensure_literal_type(value: TypeExpression) -> TypeName:
109109
110110_NON_ALNUM_RE = re .compile (r"[^a-zA-Z0-9_]+" )
111111
112+ # Literal is here because HandshakeType can be Literal[None]
113+ ROOT_FILE_HEADER = dedent (
114+ """\
115+ # Code generated by river.codegen. DO NOT EDIT.
116+ from pydantic import BaseModel
117+ from typing import Literal
118+
119+ import replit_river as river
120+
121+ """
122+ )
123+
124+ SERVICE_FILE_HEADER = dedent (
125+ """\
126+ # Code generated by river.codegen. DO NOT EDIT.
127+ from collections.abc import AsyncIterable, AsyncIterator
128+ from typing import Any
129+
130+ from pydantic import TypeAdapter
131+
132+ from replit_river.error_schema import RiverError
133+ import replit_river as river
134+
135+ """
136+ )
137+
112138FILE_HEADER = dedent (
113139 """\
114140 # ruff: noqa
@@ -709,7 +735,7 @@ def generate_common_client(
709735 handshake_chunks : Sequence [str ],
710736 modules : list [Tuple [ModuleName , ClassName ]],
711737) -> FileContents :
712- chunks : list [str ] = [FILE_HEADER ]
738+ chunks : list [str ] = [ROOT_FILE_HEADER ]
713739 chunks .extend (
714740 [
715741 f"from .{ model_name } import { class_name } "
@@ -1072,7 +1098,7 @@ async def {name}(
10721098 ]
10731099
10741100 emitted_files [RenderedPath (str (Path (f"{ schema_name } /__init__.py" )))] = FileContents (
1075- "\n " .join ([FILE_HEADER ] + rendered_imports + in_root + current_chunks )
1101+ "\n " .join ([SERVICE_FILE_HEADER ] + rendered_imports + in_root + current_chunks )
10761102 )
10771103 return (
10781104 ModuleName (schema_name ),
0 commit comments