Skip to content

Commit eaf0a07

Browse files
committed
implement encode_json for input and thread through
1 parent 82fe905 commit eaf0a07

File tree

7 files changed

+171
-30
lines changed

7 files changed

+171
-30
lines changed

src/replicate/_client.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
from __future__ import annotations
44

55
import os
6-
from typing import Any, Union, Mapping
6+
from typing import Any, Union, Mapping, Optional
77
from typing_extensions import Self, Unpack, override
88

99
import httpx
1010

11+
from replicate.lib._files import FileEncodingStrategy
1112
from replicate.lib._predictions import Model, Version, ModelVersionIdentifier
1213
from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
1314

@@ -130,14 +131,22 @@ def run(
130131
self,
131132
ref: Union[Model, Version, ModelVersionIdentifier, str],
132133
*,
134+
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
133135
use_file_output: bool = True,
134136
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
135137
**params: Unpack[PredictionCreateParamsWithoutVersion],
136138
) -> Any:
137139
"""Run a model and wait for its output."""
138140
from .lib._predictions import run
139141

140-
return run(self, ref, wait=wait, use_file_output=use_file_output, **params)
142+
return run(
143+
self,
144+
ref,
145+
wait=wait,
146+
use_file_output=use_file_output,
147+
file_encoding_strategy=file_encoding_strategy,
148+
**params,
149+
)
141150

142151
@property
143152
@override
@@ -327,13 +336,21 @@ async def run(
327336
ref: Union[Model, Version, ModelVersionIdentifier, str],
328337
*,
329338
use_file_output: bool = True,
339+
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
330340
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
331341
**params: Unpack[PredictionCreateParamsWithoutVersion],
332342
) -> Any:
333343
"""Run a model and wait for its output."""
334344
from .lib._predictions import async_run
335345

336-
return await async_run(self, ref, wait=wait, use_file_output=use_file_output, **params)
346+
return await async_run(
347+
self,
348+
ref,
349+
wait=wait,
350+
use_file_output=use_file_output,
351+
file_encoding_strategy=file_encoding_strategy,
352+
**params,
353+
)
337354

338355
@property
339356
@override

src/replicate/lib/_files.py

Lines changed: 100 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,107 @@
33
import io
44
import base64
55
import mimetypes
6-
from typing import Any, Iterator, AsyncIterator
6+
from types import GeneratorType
7+
from typing import TYPE_CHECKING, Any, Literal, Iterator, Optional, AsyncIterator
8+
from pathlib import Path
79
from typing_extensions import override
810

911
import httpx
1012

1113
from replicate.types.prediction_output import PredictionOutput
1214

1315
from .._utils import is_mapping, is_sequence
14-
from .._client import ReplicateClient, AsyncReplicateClient
16+
17+
# Use TYPE_CHECKING to avoid circular imports
18+
if TYPE_CHECKING:
19+
from .._client import ReplicateClient, AsyncReplicateClient
20+
21+
FileEncodingStrategy = Literal["base64", "url"]
22+
23+
24+
try:
25+
import numpy as np # type: ignore
26+
27+
HAS_NUMPY = True
28+
except ImportError:
29+
HAS_NUMPY = False # type: ignore
30+
31+
32+
# pylint: disable=too-many-return-statements
33+
def encode_json(
34+
obj: Any, # noqa: ANN401
35+
client: ReplicateClient,
36+
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
37+
) -> Any: # noqa: ANN401
38+
"""
39+
Return a JSON-compatible version of the object.
40+
"""
41+
42+
if isinstance(obj, dict):
43+
return {
44+
key: encode_json(value, client, file_encoding_strategy)
45+
for key, value in obj.items() # type: ignore
46+
} # type: ignore
47+
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
48+
return [encode_json(value, client, file_encoding_strategy) for value in obj] # type: ignore
49+
if isinstance(obj, Path):
50+
with obj.open("rb") as file:
51+
return encode_json(file, client, file_encoding_strategy)
52+
if isinstance(obj, io.IOBase):
53+
if file_encoding_strategy == "base64":
54+
return base64_encode_file(obj)
55+
else:
56+
# todo: support files endpoint
57+
# return client.files.create(obj).urls["get"]
58+
raise NotImplementedError("File upload is not supported yet")
59+
if HAS_NUMPY:
60+
if isinstance(obj, np.integer): # type: ignore
61+
return int(obj)
62+
if isinstance(obj, np.floating): # type: ignore
63+
return float(obj)
64+
if isinstance(obj, np.ndarray): # type: ignore
65+
return obj.tolist()
66+
return obj
67+
68+
69+
async def async_encode_json(
    obj: Any,  # noqa: ANN401
    client: "AsyncReplicateClient",
    file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
) -> Any:  # noqa: ANN401
    """
    Asynchronously return a JSON-compatible version of the object.

    Containers are walked recursively; every other value is either converted
    by one of the rules below or returned unchanged.

    Conversion rules, checked in this order:

    * ``dict`` -> new ``dict`` with the same keys and encoded values.
    * ``list`` / ``set`` / ``frozenset`` / generator / ``tuple`` -> ``list``
      of encoded items.
    * ``pathlib.Path`` -> the file is opened in binary mode and encoded like
      any other file object.
    * ``io.IOBase`` -> the result of ``base64_encode_file`` when
      ``file_encoding_strategy == "base64"``; any other strategy is rejected.
    * numpy integer / floating scalars -> ``int`` / ``float``; ``ndarray`` ->
      ``ndarray.tolist()`` (only checked when numpy could be imported).

    Args:
        obj: The value to encode.
        client: Forwarded to the recursive calls. It is not otherwise used
            yet; it is reserved for the pending file-upload support.
        file_encoding_strategy: How file objects are encoded. Only
            ``"base64"`` is currently supported.

    Returns:
        The encoded value.

    Raises:
        NotImplementedError: If a file object (or ``Path``) is encountered
            and ``file_encoding_strategy`` is not ``"base64"``.
    """

    # NOTE: ``AsyncReplicateClient`` is imported under TYPE_CHECKING only, so
    # the annotation above is quoted to avoid evaluating the name at runtime.
    if isinstance(obj, dict):
        return {
            key: (await async_encode_json(value, client, file_encoding_strategy))
            for key, value in obj.items()  # type: ignore
        }  # type: ignore
    if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
        return [
            (await async_encode_json(value, client, file_encoding_strategy))
            for value in obj  # type: ignore
        ]
    if isinstance(obj, Path):
        # Encode while the handle is still open; the context manager closes it
        # even when the file branch below raises.
        with obj.open("rb") as file:
            return await async_encode_json(file, client, file_encoding_strategy)
    if isinstance(obj, io.IOBase):
        if file_encoding_strategy == "base64":
            # TODO: This should ideally use an async based file reader path;
            # the synchronous helper is called directly here.
            return base64_encode_file(obj)
        # TODO: support the files endpoint (upload the file and return its URL).
        raise NotImplementedError("File upload is not supported yet")
    if HAS_NUMPY:
        if isinstance(obj, np.integer):  # type: ignore
            return int(obj)
        if isinstance(obj, np.floating):  # type: ignore
            return float(obj)
        if isinstance(obj, np.ndarray):  # type: ignore
            return obj.tolist()
    return obj
15107

16108

17109
def base64_encode_file(file: io.IOBase) -> str:
@@ -126,7 +218,7 @@ def __repr__(self) -> str:
126218
return f'{self.__class__.__name__}("{self.url}")'
127219

128220

129-
def transform_output(value: PredictionOutput, client: ReplicateClient | AsyncReplicateClient) -> Any:
221+
def transform_output(value: PredictionOutput, client: "ReplicateClient | AsyncReplicateClient") -> Any:
130222
"""
131223
Transform the output of a prediction to a `FileOutput` object if it's a URL.
132224
"""
@@ -137,9 +229,11 @@ def transform(obj: Any) -> Any:
137229
elif is_sequence(obj) and not isinstance(obj, str):
138230
return [transform(item) for item in obj]
139231
elif isinstance(obj, str) and (obj.startswith("https:") or obj.startswith("data:")):
140-
if isinstance(client, AsyncReplicateClient):
141-
return AsyncFileOutput(obj, client)
142-
return FileOutput(obj, client)
232+
# Check if the client is async by looking for async in the class name
233+
# we're doing this to avoid circular imports
234+
if "Async" in client.__class__.__name__:
235+
return AsyncFileOutput(obj, client) # type: ignore
236+
return FileOutput(obj, client) # type: ignore
143237
return obj
144238

145239
return transform(value)

src/replicate/lib/_predictions.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import TYPE_CHECKING, Dict, Union, Iterable, Optional
44
from typing_extensions import Unpack
55

6+
from replicate.lib._files import FileEncodingStrategy
67
from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
78

89
from ..types import PredictionOutput, PredictionCreateParams
@@ -22,6 +23,7 @@ def run(
2223
*,
2324
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
2425
use_file_output: Optional[bool] = True,
26+
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
2527
**params: Unpack[PredictionCreateParamsWithoutVersion],
2628
) -> PredictionOutput | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
2729
"""
@@ -75,15 +77,17 @@ def run(
7577
if version_id is not None:
7678
# Create prediction with the specific version ID
7779
params_with_version: PredictionCreateParams = {**params, "version": version_id}
78-
prediction = client.predictions.create(**params_with_version)
80+
prediction = client.predictions.create(file_encoding_strategy=file_encoding_strategy, **params_with_version)
7981
elif owner and name:
8082
# Create prediction via models resource with owner/name
81-
prediction = client.models.predictions.create(model_owner=owner, model_name=name, **params)
83+
prediction = client.models.predictions.create(
84+
file_encoding_strategy=file_encoding_strategy, model_owner=owner, model_name=name, **params
85+
)
8286
else:
8387
# If ref is a string but doesn't match expected patterns
8488
if isinstance(ref, str):
8589
params_with_version = {**params, "version": ref}
86-
prediction = client.predictions.create(**params_with_version)
90+
prediction = client.predictions.create(file_encoding_strategy=file_encoding_strategy, **params_with_version)
8791
else:
8892
raise ValueError(
8993
f"Invalid reference format: {ref}. Expected a model name ('owner/name'), "
@@ -110,7 +114,7 @@ def run(
110114
# TODO: Return an iterator for completed output if the model has an output iterator array type.
111115

112116
if use_file_output:
113-
return transform_output(prediction.output, client) # type: ignore[no-any-return]
117+
return transform_output(prediction.output, client) # type: ignore[no-any-return]
114118

115119
return prediction.output
116120

@@ -119,6 +123,7 @@ async def async_run(
119123
client: "AsyncReplicateClient",
120124
ref: Union[Model, Version, ModelVersionIdentifier, str],
121125
*,
126+
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
122127
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
123128
use_file_output: Optional[bool] = True,
124129
**params: Unpack[PredictionCreateParamsWithoutVersion],
@@ -174,15 +179,21 @@ async def async_run(
174179
if version_id is not None:
175180
# Create prediction with the specific version ID
176181
params_with_version: PredictionCreateParams = {**params, "version": version_id}
177-
prediction = await client.predictions.create(**params_with_version)
182+
prediction = await client.predictions.create(
183+
file_encoding_strategy=file_encoding_strategy, **params_with_version
184+
)
178185
elif owner and name:
179186
# Create prediction via models resource with owner/name
180-
prediction = await client.models.predictions.create(model_owner=owner, model_name=name, **params)
187+
prediction = await client.models.predictions.create(
188+
model_owner=owner, model_name=name, file_encoding_strategy=file_encoding_strategy, **params
189+
)
181190
else:
182191
# If ref is a string but doesn't match expected patterns
183192
if isinstance(ref, str):
184193
params_with_version = {**params, "version": ref}
185-
prediction = await client.predictions.create(**params_with_version)
194+
prediction = await client.predictions.create(
195+
file_encoding_strategy=file_encoding_strategy, **params_with_version
196+
)
186197
else:
187198
raise ValueError(
188199
f"Invalid reference format: {ref}. Expected a model name ('owner/name'), "
@@ -209,6 +220,6 @@ async def async_run(
209220
# TODO: Return an iterator for completed output if the model has an output iterator array type.
210221

211222
if use_file_output:
212-
return transform_output(prediction.output, client) # type: ignore[no-any-return]
223+
return transform_output(prediction.output, client) # type: ignore[no-any-return]
213224

214225
return prediction.output

src/replicate/resources/models/predictions.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33
from __future__ import annotations
44

5-
from typing import List
5+
from typing import List, Optional
66
from typing_extensions import Literal
77

88
import httpx
99

10+
from replicate.lib._files import FileEncodingStrategy, encode_json, async_encode_json
11+
1012
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
1113
from ..._utils import maybe_transform, strip_not_given, async_maybe_transform
1214
from ..._compat import cached_property
@@ -54,6 +56,7 @@ def create(
5456
webhook: str | NotGiven = NOT_GIVEN,
5557
webhook_events_filter: List[Literal["start", "output", "logs", "completed"]] | NotGiven = NOT_GIVEN,
5658
prefer: str | NotGiven = NOT_GIVEN,
59+
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
5760
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
5861
# The extra values given here take precedence over values defined on the client or passed to this method.
5962
extra_headers: Headers | None = None,
@@ -171,7 +174,7 @@ def create(
171174
f"/models/{model_owner}/{model_name}/predictions",
172175
body=maybe_transform(
173176
{
174-
"input": input,
177+
"input": encode_json(input, self._client, file_encoding_strategy=file_encoding_strategy),
175178
"stream": stream,
176179
"webhook": webhook,
177180
"webhook_events_filter": webhook_events_filter,
@@ -215,6 +218,7 @@ async def create(
215218
webhook: str | NotGiven = NOT_GIVEN,
216219
webhook_events_filter: List[Literal["start", "output", "logs", "completed"]] | NotGiven = NOT_GIVEN,
217220
prefer: str | NotGiven = NOT_GIVEN,
221+
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
218222
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
219223
# The extra values given here take precedence over values defined on the client or passed to this method.
220224
extra_headers: Headers | None = None,
@@ -332,7 +336,9 @@ async def create(
332336
f"/models/{model_owner}/{model_name}/predictions",
333337
body=await async_maybe_transform(
334338
{
335-
"input": input,
339+
"input": await async_encode_json(
340+
input, self._client, file_encoding_strategy=file_encoding_strategy
341+
),
336342
"stream": stream,
337343
"webhook": webhook,
338344
"webhook_events_filter": webhook_events_filter,

src/replicate/resources/predictions.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22

33
from __future__ import annotations
44

5-
from typing import List, Union
5+
from typing import List, Union, Optional
66
from datetime import datetime
77
from typing_extensions import Literal
88

99
import httpx
1010

11+
from replicate.lib._files import FileEncodingStrategy, encode_json, async_encode_json
12+
1113
from ..types import prediction_list_params, prediction_create_params
1214
from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
1315
from .._utils import maybe_transform, strip_not_given, async_maybe_transform
@@ -65,6 +67,7 @@ def create(
6567
webhook: str | NotGiven = NOT_GIVEN,
6668
webhook_events_filter: List[Literal["start", "output", "logs", "completed"]] | NotGiven = NOT_GIVEN,
6769
prefer: str | NotGiven = NOT_GIVEN,
70+
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
6871
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
6972
# The extra values given here take precedence over values defined on the client or passed to this method.
7073
extra_headers: Headers | None = None,
@@ -188,7 +191,7 @@ def create(
188191
"/predictions",
189192
body=maybe_transform(
190193
{
191-
"input": input,
194+
"input": encode_json(input, self._client, file_encoding_strategy=file_encoding_strategy),
192195
"version": version,
193196
"stream": stream,
194197
"webhook": webhook,
@@ -491,6 +494,7 @@ async def create(
491494
webhook: str | NotGiven = NOT_GIVEN,
492495
webhook_events_filter: List[Literal["start", "output", "logs", "completed"]] | NotGiven = NOT_GIVEN,
493496
prefer: str | NotGiven = NOT_GIVEN,
497+
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
494498
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
495499
# The extra values given here take precedence over values defined on the client or passed to this method.
496500
extra_headers: Headers | None = None,
@@ -614,7 +618,9 @@ async def create(
614618
"/predictions",
615619
body=await async_maybe_transform(
616620
{
617-
"input": input,
621+
"input": await async_encode_json(
622+
input, self._client, file_encoding_strategy=file_encoding_strategy
623+
),
618624
"version": version,
619625
"stream": stream,
620626
"webhook": webhook,

src/replicate/types/prediction_output.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,10 @@
77

88
# todo: this shouldn't need to be custom code. We should update the spec to include the `Optional[List[str]]` type
99
PredictionOutput: TypeAlias = Union[
10-
Optional[Dict[str, object]], Optional[List[Dict[str, object]]], Optional[List[str]], Optional[str], Optional[float], Optional[bool]
10+
Optional[Dict[str, object]],
11+
Optional[List[Dict[str, object]]],
12+
Optional[List[str]],
13+
Optional[str],
14+
Optional[float],
15+
Optional[bool],
1116
]

0 commit comments

Comments
 (0)