33import io
44import base64
55import mimetypes
6- from typing import Any , Iterator , AsyncIterator
6+ from types import GeneratorType
7+ from typing import TYPE_CHECKING , Any , Literal , Iterator , Optional , AsyncIterator
8+ from pathlib import Path
79from typing_extensions import override
810
911import httpx
1012
1113from replicate .types .prediction_output import PredictionOutput
1214
1315from .._utils import is_mapping , is_sequence
14- from .._client import ReplicateClient , AsyncReplicateClient
16+
17+ # Use TYPE_CHECKING to avoid circular imports
18+ if TYPE_CHECKING :
19+ from .._client import ReplicateClient , AsyncReplicateClient
20+
21+ FileEncodingStrategy = Literal ["base64" , "url" ]
22+
23+
24+ try :
25+ import numpy as np # type: ignore
26+
27+ HAS_NUMPY = True
28+ except ImportError :
29+ HAS_NUMPY = False # type: ignore
30+
31+
32+ # pylint: disable=too-many-return-statements
33+ def encode_json (
34+ obj : Any , # noqa: ANN401
35+ client : ReplicateClient ,
36+ file_encoding_strategy : Optional ["FileEncodingStrategy" ] = None ,
37+ ) -> Any : # noqa: ANN401
38+ """
39+ Return a JSON-compatible version of the object.
40+ """
41+
42+ if isinstance (obj , dict ):
43+ return {
44+ key : encode_json (value , client , file_encoding_strategy )
45+ for key , value in obj .items () # type: ignore
46+ } # type: ignore
47+ if isinstance (obj , (list , set , frozenset , GeneratorType , tuple )):
48+ return [encode_json (value , client , file_encoding_strategy ) for value in obj ] # type: ignore
49+ if isinstance (obj , Path ):
50+ with obj .open ("rb" ) as file :
51+ return encode_json (file , client , file_encoding_strategy )
52+ if isinstance (obj , io .IOBase ):
53+ if file_encoding_strategy == "base64" :
54+ return base64_encode_file (obj )
55+ else :
56+ # todo: support files endpoint
57+ # return client.files.create(obj).urls["get"]
58+ raise NotImplementedError ("File upload is not supported yet" )
59+ if HAS_NUMPY :
60+ if isinstance (obj , np .integer ): # type: ignore
61+ return int (obj )
62+ if isinstance (obj , np .floating ): # type: ignore
63+ return float (obj )
64+ if isinstance (obj , np .ndarray ): # type: ignore
65+ return obj .tolist ()
66+ return obj
67+
68+
69+ async def async_encode_json (
70+ obj : Any , # noqa: ANN401
71+ client : AsyncReplicateClient ,
72+ file_encoding_strategy : Optional ["FileEncodingStrategy" ] = None ,
73+ ) -> Any : # noqa: ANN401
74+ """
75+ Asynchronously return a JSON-compatible version of the object.
76+ """
77+
78+ if isinstance (obj , dict ):
79+ return {
80+ key : (await async_encode_json (value , client , file_encoding_strategy ))
81+ for key , value in obj .items () # type: ignore
82+ } # type: ignore
83+ if isinstance (obj , (list , set , frozenset , GeneratorType , tuple )):
84+ return [
85+ (await async_encode_json (value , client , file_encoding_strategy ))
86+ for value in obj # type: ignore
87+ ]
88+ if isinstance (obj , Path ):
89+ with obj .open ("rb" ) as file :
90+ return await async_encode_json (file , client , file_encoding_strategy )
91+ if isinstance (obj , io .IOBase ):
92+ if file_encoding_strategy == "base64" :
93+ # TODO: This should ideally use an async based file reader path.
94+ return base64_encode_file (obj )
95+ else :
96+ # todo: support files endpoint
97+ # return (await client.files.async_create(obj)).urls["get"]
98+ raise NotImplementedError ("File upload is not supported yet" )
99+ if HAS_NUMPY :
100+ if isinstance (obj , np .integer ): # type: ignore
101+ return int (obj )
102+ if isinstance (obj , np .floating ): # type: ignore
103+ return float (obj )
104+ if isinstance (obj , np .ndarray ): # type: ignore
105+ return obj .tolist ()
106+ return obj
15107
16108
17109def base64_encode_file (file : io .IOBase ) -> str :
@@ -126,7 +218,7 @@ def __repr__(self) -> str:
126218 return f'{ self .__class__ .__name__ } ("{ self .url } ")'
127219
128220
129- def transform_output (value : PredictionOutput , client : ReplicateClient | AsyncReplicateClient ) -> Any :
221+ def transform_output (value : PredictionOutput , client : " ReplicateClient | AsyncReplicateClient" ) -> Any :
130222 """
131223 Transform the output of a prediction to a `FileOutput` object if it's a URL.
132224 """
@@ -137,9 +229,11 @@ def transform(obj: Any) -> Any:
137229 elif is_sequence (obj ) and not isinstance (obj , str ):
138230 return [transform (item ) for item in obj ]
139231 elif isinstance (obj , str ) and (obj .startswith ("https:" ) or obj .startswith ("data:" )):
140- if isinstance (client , AsyncReplicateClient ):
141- return AsyncFileOutput (obj , client )
142- return FileOutput (obj , client )
232+ # Check if the client is async by looking for async in the class name
233+ # we're doing this to avoid circular imports
234+ if "Async" in client .__class__ .__name__ :
235+ return AsyncFileOutput (obj , client ) # type: ignore
236+ return FileOutput (obj , client ) # type: ignore
143237 return obj
144238
145239 return transform (value )
0 commit comments