
Commit 686d4d7

Merge pull request #5 from replicate/sam/run-helper
feat(client): first pass at .run() helper
2 parents 0e4a103 + e8af3f1 · commit 686d4d7

18 files changed (+1703 -28 lines)

examples/run_a_model.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+import rich
+
+import replicate
+
+outputs = replicate.run(
+    "black-forest-labs/flux-schnell",
+    input={"prompt": "astronaut riding a rocket like a horse"},
+)
+rich.print(outputs)
+for index, output in enumerate(outputs):
+    with open(f"output_{index}.webp", "wb") as file:
+        file.write(output.read())

examples/run_async.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+import asyncio
+
+from replicate import AsyncReplicate
+
+replicate = AsyncReplicate()
+
+# https://replicate.com/stability-ai/sdxl
+model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
+prompts = [f"A chariot pulled by a team of {count} rainbow unicorns" for count in ["two", "four", "six", "eight"]]
+
+
+async def main() -> None:
+    # Create tasks with asyncio.gather directly
+    tasks = [replicate.run(model_version, input={"prompt": prompt}) for prompt in prompts]
+
+    results = await asyncio.gather(*tasks)
+    print(results)
+
+
+asyncio.run(main())
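
A possible extension of the example above, reusing the same client, model_version, and prompts (a sketch, not part of the commit): passing return_exceptions=True to asyncio.gather keeps one failed prediction from discarding the other results.

async def main_tolerant() -> None:
    coros = [replicate.run(model_version, input={"prompt": prompt}) for prompt in prompts]
    results = await asyncio.gather(*coros, return_exceptions=True)
    for prompt, result in zip(prompts, results):
        if isinstance(result, BaseException):
            # A failed run surfaces here instead of cancelling the whole batch.
            print(f"{prompt!r} failed: {result}")
        else:
            print(f"{prompt!r} -> {result}")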

src/replicate/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -22,6 +22,7 @@
 from ._version import __title__, __version__
 from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
 from ._constants import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, DEFAULT_CONNECTION_LIMITS
+from .lib._files import FileOutput as FileOutput, AsyncFileOutput as AsyncFileOutput
 from ._exceptions import (
     APIError,
     ConflictError,
@@ -38,6 +39,7 @@
     UnprocessableEntityError,
     APIResponseValidationError,
 )
+from .lib._models import Model as Model, Version as Version, ModelVersionIdentifier as ModelVersionIdentifier
 from ._base_client import DefaultHttpxClient, DefaultAsyncHttpxClient
 from ._utils._logs import setup_logging as _setup_logging
 
@@ -80,6 +82,11 @@
     "DEFAULT_CONNECTION_LIMITS",
     "DefaultHttpxClient",
     "DefaultAsyncHttpxClient",
+    "FileOutput",
+    "AsyncFileOutput",
+    "Model",
+    "Version",
+    "ModelVersionIdentifier",
 ]
 
 _setup_logging()
@@ -230,6 +237,7 @@ def _reset_client() -> None:  # type: ignore[reportUnusedFunction]
 
 
 from ._module_client import (
+    run as run,
     files as files,
     models as models,
     account as account,
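
A small sketch of the widened top-level surface after this change (the imported names are exactly those added to __all__ above; nothing else is assumed):

import replicate
from replicate import AsyncFileOutput, FileOutput, Model, ModelVersionIdentifier, Version

# The module-level run helper is re-exported alongside the resource proxies.
print(callable(replicate.run))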

src/replicate/_client.py

Lines changed: 110 additions & 2 deletions
@@ -3,11 +3,15 @@
 from __future__ import annotations
 
 import os
-from typing import TYPE_CHECKING, Any, Union, Mapping
-from typing_extensions import Self, override
+from typing import TYPE_CHECKING, Any, Union, Mapping, Optional
+from typing_extensions import Self, Unpack, override
 
 import httpx
 
+from replicate.lib._files import FileEncodingStrategy
+from replicate.lib._predictions import Model, Version, ModelVersionIdentifier
+from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
+
 from . import _exceptions
 from ._qs import Querystring
 from ._types import (
@@ -171,6 +175,10 @@ def with_raw_response(self) -> ReplicateWithRawResponse:
     def with_streaming_response(self) -> ReplicateWithStreamedResponse:
         return ReplicateWithStreamedResponse(self)
 
+    @cached_property
+    def poll_interval(self) -> float:
+        return float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
+
     @property
     @override
     def qs(self) -> Querystring:
@@ -191,6 +199,54 @@ def default_headers(self) -> dict[str, str | Omit]:
             **self._custom_headers,
         }
 
+    def run(
+        self,
+        ref: Union[Model, Version, ModelVersionIdentifier, str],
+        *,
+        file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
+        use_file_output: bool = True,
+        wait: Union[int, bool, NotGiven] = NOT_GIVEN,
+        **params: Unpack[PredictionCreateParamsWithoutVersion],
+    ) -> Any:
+        """
+        Run a model prediction.
+
+        Args:
+            ref: Reference to the model or version to run. Can be:
+                - A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
+                - A string with owner/name format (e.g. "replicate/hello-world")
+                - A string with owner/name:version format (e.g. "replicate/hello-world:5c7d5dc6...")
+                - A Model instance with owner and name attributes
+                - A Version instance with id attribute
+                - A ModelVersionIdentifier dictionary with owner, name, and/or version keys
+            file_encoding_strategy: Strategy for encoding file inputs, options are "base64" or "url"
+            use_file_output: If True (default), convert output URLs to FileOutput objects
+            wait: If True (default), wait for the prediction to complete. If False, return immediately.
+                If an integer, wait up to that many seconds.
+            **params: Additional parameters to pass to the prediction creation endpoint including
+                the required "input" dictionary with model-specific parameters
+
+        Returns:
+            The prediction output, which could be a basic type (str, int, etc.), a FileOutput object,
+            a list of FileOutput objects, or a dictionary of FileOutput objects, depending on what
+            the model returns.
+
+        Raises:
+            ModelError: If the model run fails
+            ValueError: If the reference format is invalid
+            TypeError: If both wait and prefer parameters are provided
+        """
+        from .lib._predictions import run
+
+        return run(
+            self,
+            ref,
+            wait=wait,
+            use_file_output=use_file_output,
+            file_encoding_strategy=file_encoding_strategy,
+            **params,
+        )
+
     def copy(
         self,
         *,
@@ -393,6 +449,10 @@ def with_raw_response(self) -> AsyncReplicateWithRawResponse:
     def with_streaming_response(self) -> AsyncReplicateWithStreamedResponse:
         return AsyncReplicateWithStreamedResponse(self)
 
+    @cached_property
+    def poll_interval(self) -> float:
+        return float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
+
     @property
     @override
     def qs(self) -> Querystring:
@@ -413,6 +473,54 @@ def default_headers(self) -> dict[str, str | Omit]:
             **self._custom_headers,
         }
 
+    async def run(
+        self,
+        ref: Union[Model, Version, ModelVersionIdentifier, str],
+        *,
+        use_file_output: bool = True,
+        file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
+        wait: Union[int, bool, NotGiven] = NOT_GIVEN,
+        **params: Unpack[PredictionCreateParamsWithoutVersion],
+    ) -> Any:
+        """
+        Run a model prediction asynchronously.
+
+        Args:
+            ref: Reference to the model or version to run. Can be:
+                - A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
+                - A string with owner/name format (e.g. "replicate/hello-world")
+                - A string with owner/name:version format (e.g. "replicate/hello-world:5c7d5dc6...")
+                - A Model instance with owner and name attributes
+                - A Version instance with id attribute
+                - A ModelVersionIdentifier dictionary with owner, name, and/or version keys
+            use_file_output: If True (default), convert output URLs to AsyncFileOutput objects
+            file_encoding_strategy: Strategy for encoding file inputs, options are "base64" or "url"
+            wait: If True (default), wait for the prediction to complete. If False, return immediately.
+                If an integer, wait up to that many seconds.
+            **params: Additional parameters to pass to the prediction creation endpoint including
+                the required "input" dictionary with model-specific parameters
+
+        Returns:
+            The prediction output, which could be a basic type (str, int, etc.), an AsyncFileOutput object,
+            a list of AsyncFileOutput objects, or a dictionary of AsyncFileOutput objects, depending on what
+            the model returns.
+
+        Raises:
+            ModelError: If the model run fails
+            ValueError: If the reference format is invalid
+            TypeError: If both wait and prefer parameters are provided
+        """
+        from .lib._predictions import async_run
+
+        return await async_run(
+            self,
+            ref,
+            wait=wait,
+            use_file_output=use_file_output,
+            file_encoding_strategy=file_encoding_strategy,
+            **params,
+        )
+
     def copy(
         self,
         *,
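
A minimal usage sketch of the client-level helper added above. The model ref and prompt are illustrative, and it assumes the Replicate client class is exported at the package root with credentials read from the environment; the wait values follow the docstring (True blocks until completion, an integer waits up to that many seconds):

from replicate import Replicate

client = Replicate()

# Default: block until the prediction completes and return FileOutput objects.
outputs = client.run(
    "black-forest-labs/flux-schnell",
    input={"prompt": "a lighthouse in a thunderstorm"},
)

# An integer waits up to that many seconds, per the docstring above.
outputs = client.run(
    "black-forest-labs/flux-schnell",
    wait=30,
    input={"prompt": "a lighthouse in a thunderstorm"},
)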

src/replicate/_exceptions.py

Lines changed: 13 additions & 0 deletions
@@ -6,6 +6,8 @@
 
 import httpx
 
+from replicate.types.prediction import Prediction
+
 __all__ = [
     "BadRequestError",
     "AuthenticationError",
@@ -15,6 +17,7 @@
     "UnprocessableEntityError",
     "RateLimitError",
     "InternalServerError",
+    "ModelError",
 ]
 
 
@@ -106,3 +109,13 @@ class RateLimitError(APIStatusError):
 
 class InternalServerError(APIStatusError):
     pass
+
+
+class ModelError(ReplicateError):
+    """An error from user's code in a model."""
+
+    prediction: Prediction
+
+    def __init__(self, prediction: Prediction) -> None:
+        self.prediction = prediction
+        super().__init__(prediction.error)
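
A short sketch of how the new ModelError might be handled by callers; the model ref and input are illustrative, while ModelError and its prediction attribute are exactly as defined above:

import replicate
from replicate._exceptions import ModelError

try:
    replicate.run("replicate/hello-world", input={"text": "hi"})
except ModelError as err:
    # The failed Prediction is attached, so its error message can be inspected directly.
    print(err.prediction.error)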

src/replicate/_module_client.py

Lines changed: 15 additions & 1 deletion
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
-from typing_extensions import override
+from typing_extensions import cast, override
 
 if TYPE_CHECKING:
     from .resources.files import FilesResource
@@ -74,6 +74,20 @@ def __load__(self) -> PredictionsResource:
         return _load_client().predictions
 
 
+if TYPE_CHECKING:
+    from ._client import Replicate
+
+    # get the type checker to infer the run symbol to the same type
+    # as the method on the client so we don't have to define it twice
+    __client: Replicate = cast(Replicate, {})
+    run = __client.run
+else:
+
+    def _run(*args, **kwargs):
+        return _load_client().run(*args, **kwargs)
+
+    run = _run
+
 files: FilesResource = FilesResourceProxy().__as_proxied__()
 models: ModelsResource = ModelsResourceProxy().__as_proxied__()
 account: AccountResource = AccountResourceProxy().__as_proxied__()
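
The TYPE_CHECKING branch above lets type checkers see module-level run with the same signature as the client method, while the runtime branch defers client construction to the first call. A standalone sketch of the same pattern, with illustrative names that are not part of the SDK:

from __future__ import annotations

from typing import TYPE_CHECKING, cast


class _Client:
    def run(self, ref: str, *, input: dict) -> str:
        return f"ran {ref} with {input}"


def _load_client() -> _Client:
    # Stands in for the SDK's lazy construction of the default client.
    return _Client()


if TYPE_CHECKING:
    # Type checkers bind `run` to the method, so its signature is defined only once.
    run = cast(_Client, {}).run
else:
    # At runtime the client is built on first use and the call is forwarded.
    def run(*args, **kwargs):
        return _load_client().run(*args, **kwargs)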
