1+ import base64
12import io
3+ import mimetypes
24from pathlib import Path
35from types import GeneratorType
4- from typing import Any , Callable
6+ from typing import TYPE_CHECKING , Any , Optional
7+
8+ if TYPE_CHECKING :
9+ from replicate .client import Client
10+ from replicate .file import FileEncodingStrategy
11+
512
613try :
714 import numpy as np # type: ignore
1421# pylint: disable=too-many-return-statements
1522def encode_json (
1623 obj : Any , # noqa: ANN401
17- upload_file : Callable [[io .IOBase ], str ],
24+ client : "Client" ,
25+ file_encoding_strategy : Optional ["FileEncodingStrategy" ] = None ,
1826) -> Any : # noqa: ANN401
1927 """
2028 Return a JSON-compatible version of the object.
2129 """
22- # Effectively the same thing as cog.json.encode_json.
2330
2431 if isinstance (obj , dict ):
25- return {key : encode_json (value , upload_file ) for key , value in obj .items ()}
32+ return {
33+ key : encode_json (value , client , file_encoding_strategy )
34+ for key , value in obj .items ()
35+ }
36+ if isinstance (obj , (list , set , frozenset , GeneratorType , tuple )):
37+ return [encode_json (value , client , file_encoding_strategy ) for value in obj ]
38+ if isinstance (obj , Path ):
39+ with obj .open ("rb" ) as file :
40+ return encode_json (file , client , file_encoding_strategy )
41+ if isinstance (obj , io .IOBase ):
42+ if file_encoding_strategy == "base64" :
43+ return base64 .b64encode (obj .read ()).decode ("utf-8" )
44+ else :
45+ return client .files .create (obj ).urls ["get" ]
46+ if HAS_NUMPY :
47+ if isinstance (obj , np .integer ): # type: ignore
48+ return int (obj )
49+ if isinstance (obj , np .floating ): # type: ignore
50+ return float (obj )
51+ if isinstance (obj , np .ndarray ): # type: ignore
52+ return obj .tolist ()
53+ return obj
54+
55+
56+ async def async_encode_json (
57+ obj : Any , # noqa: ANN401
58+ client : "Client" ,
59+ file_encoding_strategy : Optional ["FileEncodingStrategy" ] = None ,
60+ ) -> Any : # noqa: ANN401
61+ """
62+ Asynchronously return a JSON-compatible version of the object.
63+ """
64+
65+ if isinstance (obj , dict ):
66+ return {
67+ key : (await async_encode_json (value , client , file_encoding_strategy ))
68+ for key , value in obj .items ()
69+ }
2670 if isinstance (obj , (list , set , frozenset , GeneratorType , tuple )):
27- return [encode_json (value , upload_file ) for value in obj ]
71+ return [
72+ (await async_encode_json (value , client , file_encoding_strategy ))
73+ for value in obj
74+ ]
2875 if isinstance (obj , Path ):
2976 with obj .open ("rb" ) as file :
30- return upload_file (file )
77+ return encode_json (file , client , file_encoding_strategy )
3178 if isinstance (obj , io .IOBase ):
32- return upload_file ( obj )
79+ return ( await client . files . async_create ( obj )). urls [ "get" ]
3380 if HAS_NUMPY :
3481 if isinstance (obj , np .integer ): # type: ignore
3582 return int (obj )
@@ -38,3 +85,26 @@ def encode_json(
3885 if isinstance (obj , np .ndarray ): # type: ignore
3986 return obj .tolist ()
4087 return obj
88+
89+
90+ def base64_encode_file (file : io .IOBase ) -> str :
91+ """
92+ Base64 encode a file.
93+
94+ Args:
95+ file: A file handle to upload.
96+ Returns:
97+ str: A base64-encoded data URI.
98+ """
99+
100+ file .seek (0 )
101+ body = file .read ()
102+
103+ # Ensure the file handle is in bytes
104+ body = body .encode ("utf-8" ) if isinstance (body , str ) else body
105+ encoded_body = base64 .b64encode (body ).decode ("utf-8" )
106+
107+ mime_type = (
108+ mimetypes .guess_type (getattr (file , "name" , "" ))[0 ] or "application/octet-stream"
109+ )
110+ return f"data:{ mime_type } ;base64,{ encoded_body } "
0 commit comments