Skip to content

Commit e227b9a

Browse files
committed
feat(client): first pass at .run() helper
1 parent a812f1b commit e227b9a

File tree

9 files changed

+357
-7
lines changed

9 files changed

+357
-7
lines changed

examples/demo.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from typing import Iterable, cast
2+
3+
import rich
4+
5+
from replicate import FileOutput, ReplicateClient
6+
7+
client = ReplicateClient()
8+
9+
outputs = client.run(
10+
"black-forest-labs/flux-schnell",
11+
input={"prompt": "astronaut riding a rocket like a horse"},
12+
)
13+
rich.print(outputs)
14+
for index, output in enumerate(cast(Iterable[FileOutput], outputs)):
15+
with open(f"output_{index}.webp", "wb") as file:
16+
file.write(output.read())

src/replicate/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ._version import __title__, __version__
2323
from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
2424
from ._constants import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, DEFAULT_CONNECTION_LIMITS
25+
from .lib._files import FileOutput as FileOutput, AsyncFileOutput as AsyncFileOutput
2526
from ._exceptions import (
2627
APIError,
2728
ConflictError,
@@ -80,6 +81,8 @@
8081
"DEFAULT_CONNECTION_LIMITS",
8182
"DefaultHttpxClient",
8283
"DefaultAsyncHttpxClient",
84+
"FileOutput",
85+
"AsyncFileOutput",
8386
]
8487

8588
_setup_logging()
@@ -230,6 +233,7 @@ def _reset_client() -> None: # type: ignore[reportUnusedFunction]
230233

231234

232235
from ._module_client import (
236+
run as run,
233237
models as models,
234238
accounts as accounts,
235239
hardware as hardware,

src/replicate/_client.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
from __future__ import annotations
44

55
import os
6-
from typing import Any, Union, Mapping
7-
from typing_extensions import Self, override
6+
from typing import TYPE_CHECKING, Any, Dict, Union, Mapping, Iterable
7+
from typing_extensions import Self, Unpack, override
88

99
import httpx
1010

1111
from . import _exceptions
1212
from ._qs import Querystring
13+
from .types import PredictionOutput, PredictionCreateParams
1314
from ._types import (
1415
NOT_GIVEN,
1516
Omit,
@@ -36,6 +37,9 @@
3637
from .resources.webhooks import webhooks
3738
from .resources.deployments import deployments
3839

40+
if TYPE_CHECKING:
41+
from .lib._files import FileOutput
42+
3943
__all__ = [
4044
"Timeout",
4145
"Transport",
@@ -124,6 +128,19 @@ def __init__(
124128
self.webhooks = webhooks.WebhooksResource(self)
125129
self.with_raw_response = ReplicateClientWithRawResponse(self)
126130
self.with_streaming_response = ReplicateClientWithStreamedResponse(self)
131+
self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
132+
133+
def run(
134+
self,
135+
ref: str,
136+
*,
137+
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
138+
**params: Unpack[PredictionCreateParams],
139+
) -> PredictionOutput | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
140+
"""Run a model and wait for its output."""
141+
from .lib._predictions import run
142+
143+
return run(self, ref, wait=wait, **params)
127144

128145
@property
129146
@override
@@ -306,6 +323,19 @@ def __init__(
306323
self.webhooks = webhooks.AsyncWebhooksResource(self)
307324
self.with_raw_response = AsyncReplicateClientWithRawResponse(self)
308325
self.with_streaming_response = AsyncReplicateClientWithStreamedResponse(self)
326+
self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
327+
328+
async def run(
329+
self,
330+
ref: str,
331+
*,
332+
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
333+
**params: Unpack[PredictionCreateParams],
334+
) -> PredictionOutput | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
335+
"""Run a model and wait for its output."""
336+
from .lib._predictions import async_run
337+
338+
return await async_run(self, ref, wait=wait, **params)
309339

310340
@property
311341
@override

src/replicate/_exceptions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import httpx
88

9+
from replicate.types.prediction import Prediction
10+
911
__all__ = [
1012
"BadRequestError",
1113
"AuthenticationError",
@@ -15,6 +17,7 @@
1517
"UnprocessableEntityError",
1618
"RateLimitError",
1719
"InternalServerError",
20+
"ModelError",
1821
]
1922

2023

@@ -106,3 +109,13 @@ class RateLimitError(APIStatusError):
106109

107110
class InternalServerError(APIStatusError):
108111
pass
112+
113+
114+
class ModelError(ReplicateClientError):
115+
"""An error from user's code in a model."""
116+
117+
prediction: Prediction
118+
119+
def __init__(self, prediction: Prediction) -> None:
120+
self.prediction = prediction
121+
super().__init__(prediction.error)

src/replicate/_module_client.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
22

3-
from typing_extensions import override
3+
from typing import TYPE_CHECKING
4+
from typing_extensions import cast, override
45

56
from . import resources, _load_client
67
from ._utils import LazyProxy
@@ -62,3 +63,17 @@ def __load__(self) -> resources.PredictionsResource:
6263
collections: resources.CollectionsResource = CollectionsResourceProxy().__as_proxied__()
6364
deployments: resources.DeploymentsResource = DeploymentsResourceProxy().__as_proxied__()
6465
predictions: resources.PredictionsResource = PredictionsResourceProxy().__as_proxied__()
66+
67+
if TYPE_CHECKING:
68+
from ._client import ReplicateClient
69+
70+
# get the type checker to infer the run symbol to the same type
71+
# as the method on the client so we don't have to define it twice
72+
__client: ReplicateClient = cast(ReplicateClient, {})
73+
run = __client.run
74+
else:
75+
76+
def _run(*args, **kwargs):
77+
return _load_client().run(*args, **kwargs)
78+
79+
run = _run

src/replicate/lib/_files.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
from __future__ import annotations
2+
3+
import io
4+
import base64
5+
import mimetypes
6+
from typing import Any, Iterator, AsyncIterator
7+
from typing_extensions import override
8+
9+
import httpx
10+
11+
from .._utils import is_mapping, is_sequence
12+
from .._client import ReplicateClient, AsyncReplicateClient
13+
14+
15+
def base64_encode_file(file: io.IOBase) -> str:
16+
"""
17+
Base64 encode a file.
18+
19+
Args:
20+
file: A file handle to upload.
21+
Returns:
22+
str: A base64-encoded data URI.
23+
"""
24+
25+
file.seek(0)
26+
body = file.read()
27+
28+
# Ensure the file handle is in bytes
29+
body = body.encode("utf-8") if isinstance(body, str) else body
30+
encoded_body = base64.b64encode(body).decode("utf-8")
31+
32+
mime_type = mimetypes.guess_type(getattr(file, "name", ""))[0] or "application/octet-stream"
33+
return f"data:{mime_type};base64,{encoded_body}"
34+
35+
36+
class FileOutput(httpx.SyncByteStream):
37+
"""
38+
An object that can be used to read the contents of an output file
39+
created by running a Replicate model.
40+
"""
41+
42+
url: str
43+
"""
44+
The file URL.
45+
"""
46+
47+
_client: ReplicateClient
48+
49+
def __init__(self, url: str, client: ReplicateClient) -> None:
50+
self.url = url
51+
self._client = client
52+
53+
def read(self) -> bytes:
54+
if self.url.startswith("data:"):
55+
_, encoded = self.url.split(",", 1)
56+
return base64.b64decode(encoded)
57+
58+
with self._client._client.stream("GET", self.url) as response:
59+
response.raise_for_status()
60+
return response.read()
61+
62+
@override
63+
def __iter__(self) -> Iterator[bytes]:
64+
if self.url.startswith("data:"):
65+
yield self.read()
66+
return
67+
68+
with self._client._client.stream("GET", self.url) as response:
69+
response.raise_for_status()
70+
yield from response.iter_bytes()
71+
72+
@override
73+
def __str__(self) -> str:
74+
return self.url
75+
76+
@override
77+
def __repr__(self) -> str:
78+
return f'{self.__class__.__name__}("{self.url}")'
79+
80+
81+
class AsyncFileOutput(httpx.AsyncByteStream):
82+
"""
83+
An object that can be used to read the contents of an output file
84+
created by running a Replicate model.
85+
"""
86+
87+
url: str
88+
"""
89+
The file URL.
90+
"""
91+
92+
_client: AsyncReplicateClient
93+
94+
def __init__(self, url: str, client: AsyncReplicateClient) -> None:
95+
self.url = url
96+
self._client = client
97+
98+
async def read(self) -> bytes:
99+
if self.url.startswith("data:"):
100+
_, encoded = self.url.split(",", 1)
101+
return base64.b64decode(encoded)
102+
103+
async with self._client._client.stream("GET", self.url) as response:
104+
response.raise_for_status()
105+
return await response.aread()
106+
107+
@override
108+
async def __aiter__(self) -> AsyncIterator[bytes]:
109+
if self.url.startswith("data:"):
110+
yield await self.read()
111+
return
112+
113+
async with self._client._client.stream("GET", self.url) as response:
114+
response.raise_for_status()
115+
async for chunk in response.aiter_bytes():
116+
yield chunk
117+
118+
@override
119+
def __str__(self) -> str:
120+
return self.url
121+
122+
@override
123+
def __repr__(self) -> str:
124+
return f'{self.__class__.__name__}("{self.url}")'
125+
126+
127+
def transform_output(value: Any, client: ReplicateClient | AsyncReplicateClient) -> Any:
128+
"""
129+
Transform the output of a prediction to a `FileOutput` object if it's a URL.
130+
"""
131+
132+
def transform(obj: Any) -> Any:
133+
if is_mapping(obj):
134+
return {k: transform(v) for k, v in obj.items()}
135+
elif is_sequence(obj) and not isinstance(obj, str):
136+
return [transform(item) for item in obj]
137+
elif isinstance(obj, str) and (obj.startswith("https:") or obj.startswith("data:")):
138+
if isinstance(client, AsyncReplicateClient):
139+
return AsyncFileOutput(obj, client)
140+
return FileOutput(obj, client)
141+
return obj
142+
143+
return transform(value)

0 commit comments

Comments
 (0)