Merged

Changes from all commits · 34 commits
741df3a  feat(client): first pass at .run() helper (dgellow, Apr 17, 2025)
cfcf1a7  fix linter issues (dtmeadows, Apr 17, 2025)
36a65e6  Update run() return type to Any (dgellow, Apr 23, 2025)
38cdddc  Fix formatting (dgellow, Apr 23, 2025)
ffdaeda  Fix async run() (dgellow, Apr 23, 2025)
c6f76bb  Add run() tests (dgellow, Apr 23, 2025)
4a25c43  Add run() tests (dgellow, Apr 23, 2025)
85d1942  Merge remote-tracking branch 'origin/next' into sam/run-helper (dtmeadows, May 5, 2025)
00723b4  add back separate models (dtmeadows, May 5, 2025)
46b1f8d  clean up tests (dtmeadows, May 5, 2025)
b969e71  add support for more ref types (dtmeadows, May 5, 2025)
7b1a1cc  forgot the file (dtmeadows, May 5, 2025)
c9f34c5  add support for use_file_output (dtmeadows, May 5, 2025)
1804f7c  temp fix to generated type (dtmeadows, May 5, 2025)
c80c7f3  fix up tests (dtmeadows, May 5, 2025)
82fe905  fix types and skip 2 remaining tests (dtmeadows, May 5, 2025)
eaf0a07  implement encode_json for input and thread through (dtmeadows, May 5, 2025)
6824251  adjust typing to just be Any for PredictionOutput (dtmeadows, May 6, 2025)
eadfb53  Merge remote-tracking branch 'origin/next' into sam/run-helper (dtmeadows, May 6, 2025)
cf383c6  finish last bits of todos (dtmeadows, May 6, 2025)
24d30f4  updates to support async example (dtmeadows, May 6, 2025)
11e09d3  clean up testing a bit more (dtmeadows, May 6, 2025)
0257edf  move docs up into client (dtmeadows, May 6, 2025)
00970c0  Merge remote-tracking branch 'origin/next' into sam/run-helper (dtmeadows, May 7, 2025)
c7216db  clean up helpers to match underlying api changes (dtmeadows, May 7, 2025)
80b28fc  clean up and use better import (dtmeadows, May 7, 2025)
a5c02d4  fixup! (dtmeadows, May 7, 2025)
aef2230  Merge remote-tracking branch 'origin/next' into sam/run-helper (dtmeadows, May 7, 2025)
0ec2897  fix files uploading (dtmeadows, May 7, 2025)
2143731  add test for file upload (dtmeadows, May 7, 2025)
0beb9ef  Revert "add test for file upload" (dtmeadows, May 7, 2025)
c9c4b8c  Revert "fix files uploading" (dtmeadows, May 7, 2025)
ad4da10  fixup! (dtmeadows, May 7, 2025)
e8af3f1  Merge remote-tracking branch 'origin/next' into sam/run-helper (dtmeadows, May 8, 2025)
examples/run_a_model.py · 12 additions, 0 deletions
@@ -0,0 +1,12 @@
import rich

import replicate

outputs = replicate.run(
    "black-forest-labs/flux-schnell",
    input={"prompt": "astronaut riding a rocket like a horse"},
)
rich.print(outputs)
for index, output in enumerate(outputs):
    with open(f"output_{index}.webp", "wb") as file:
        file.write(output.read())
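
For contrast, a minimal sketch of the same call with use_file_output=False; per the run() docstring later in this diff, the outputs are then assumed to arrive as the model's raw values rather than FileOutput objects:

import replicate

# Hedged sketch: use_file_output=False skips the FileOutput conversion,
# so run() should return the raw output (e.g. URL strings).
outputs = replicate.run(
    "black-forest-labs/flux-schnell",
    input={"prompt": "astronaut riding a rocket like a horse"},
    use_file_output=False,
)
print(outputs)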
examples/run_async.py · 20 additions, 0 deletions
@@ -0,0 +1,20 @@
import asyncio

from replicate import AsyncReplicate

replicate = AsyncReplicate()

# https://replicate.com/stability-ai/sdxl
model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
prompts = [f"A chariot pulled by a team of {count} rainbow unicorns" for count in ["two", "four", "six", "eight"]]


async def main() -> None:
    # Run all prompts concurrently and gather the results in order
    tasks = [replicate.run(model_version, input={"prompt": prompt}) for prompt in prompts]

    results = await asyncio.gather(*tasks)
    print(results)


asyncio.run(main())
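
A companion sketch of saving a single async result (assumptions: a flux-schnell-style model that returns file outputs, and an awaitable AsyncFileOutput.read() mirroring the synchronous FileOutput.read() in run_a_model.py):

async def save_one() -> None:
    outputs = await replicate.run(
        "black-forest-labs/flux-schnell",
        input={"prompt": "a watercolor fox"},
    )
    # read() on AsyncFileOutput is assumed to be awaitable here.
    with open("output.webp", "wb") as file:
        file.write(await outputs[0].read())

Calling it from main() above keeps everything on one event loop.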
src/replicate/__init__.py · 8 additions, 0 deletions
@@ -22,6 +22,7 @@
from ._version import __title__, __version__
from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
from ._constants import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, DEFAULT_CONNECTION_LIMITS
from .lib._files import FileOutput as FileOutput, AsyncFileOutput as AsyncFileOutput
from ._exceptions import (
    APIError,
    ConflictError,
@@ -38,6 +39,7 @@
    UnprocessableEntityError,
    APIResponseValidationError,
)
from .lib._models import Model as Model, Version as Version, ModelVersionIdentifier as ModelVersionIdentifier
from ._base_client import DefaultHttpxClient, DefaultAsyncHttpxClient
from ._utils._logs import setup_logging as _setup_logging

@@ -80,6 +82,11 @@
    "DEFAULT_CONNECTION_LIMITS",
    "DefaultHttpxClient",
    "DefaultAsyncHttpxClient",
    "FileOutput",
    "AsyncFileOutput",
    "Model",
    "Version",
    "ModelVersionIdentifier",
]

_setup_logging()
@@ -230,6 +237,7 @@ def _reset_client() -> None: # type: ignore[reportUnusedFunction]


from ._module_client import (
    run as run,
    files as files,
    models as models,
    account as account,
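
With these exports in place, the new helper types and the module-level run() are importable from the package root; a quick sketch (the hello-world ref comes from the docstrings below):

import replicate
from replicate import FileOutput, Model, Version, ModelVersionIdentifier

output = replicate.run("replicate/hello-world", input={"text": "hi"})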
src/replicate/_client.py · 110 additions, 2 deletions
@@ -3,11 +3,15 @@
from __future__ import annotations

import os
-from typing import TYPE_CHECKING, Any, Union, Mapping
-from typing_extensions import Self, override
+from typing import TYPE_CHECKING, Any, Union, Mapping, Optional
+from typing_extensions import Self, Unpack, override

import httpx

from replicate.lib._files import FileEncodingStrategy
from replicate.lib._predictions import Model, Version, ModelVersionIdentifier
from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion

from . import _exceptions
from ._qs import Querystring
from ._types import (
@@ -171,6 +175,10 @@ def with_raw_response(self) -> ReplicateWithRawResponse:
    def with_streaming_response(self) -> ReplicateWithStreamedResponse:
        return ReplicateWithStreamedResponse(self)

    @cached_property
    def poll_interval(self) -> float:
        return float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))

    @property
    @override
    def qs(self) -> Querystring:
@@ -191,6 +199,54 @@ def default_headers(self) -> dict[str, str | Omit]:
            **self._custom_headers,
        }

    def run(
        self,
        ref: Union[Model, Version, ModelVersionIdentifier, str],
        *,
        file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
        use_file_output: bool = True,
        wait: Union[int, bool, NotGiven] = NOT_GIVEN,
        **params: Unpack[PredictionCreateParamsWithoutVersion],
    ) -> Any:
        """
        Run a model prediction.

        Args:
            ref: Reference to the model or version to run. Can be:
                - A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
                - A string in owner/name format (e.g. "replicate/hello-world")
                - A string in owner/name:version format (e.g. "replicate/hello-world:5c7d5dc6...")
                - A Model instance with owner and name attributes
                - A Version instance with an id attribute
                - A ModelVersionIdentifier dictionary with owner, name, and/or version keys
            file_encoding_strategy: Strategy for encoding file inputs; options are "base64" or "url"
            use_file_output: If True (default), convert output URLs to FileOutput objects
            wait: If True (default), block until the prediction completes. If False, return immediately.
                If an integer, wait up to that many seconds.
            **params: Additional parameters to pass to the prediction creation endpoint, including
                the required "input" dictionary of model-specific parameters

        Returns:
            The prediction output, which may be a basic type (str, int, etc.), a FileOutput object,
            a list of FileOutput objects, or a dictionary of FileOutput objects, depending on what
            the model returns.

        Raises:
            ModelError: If the model run fails
            ValueError: If the reference format is invalid
            TypeError: If both the wait and prefer parameters are provided
        """
        from .lib._predictions import run

        return run(
            self,
            ref,
            wait=wait,
            use_file_output=use_file_output,
            file_encoding_strategy=file_encoding_strategy,
            **params,
        )

    def copy(
        self,
        *,
@@ -393,6 +449,10 @@ def with_raw_response(self) -> AsyncReplicateWithRawResponse:
    def with_streaming_response(self) -> AsyncReplicateWithStreamedResponse:
        return AsyncReplicateWithStreamedResponse(self)

    @cached_property
    def poll_interval(self) -> float:
        return float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))

    @property
    @override
    def qs(self) -> Querystring:
@@ -413,6 +473,54 @@ def default_headers(self) -> dict[str, str | Omit]:
            **self._custom_headers,
        }

    async def run(
        self,
        ref: Union[Model, Version, ModelVersionIdentifier, str],
        *,
        use_file_output: bool = True,
        file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
        wait: Union[int, bool, NotGiven] = NOT_GIVEN,
        **params: Unpack[PredictionCreateParamsWithoutVersion],
    ) -> Any:
        """
        Run a model prediction asynchronously.

        Args:
            ref: Reference to the model or version to run. Can be:
                - A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
                - A string in owner/name format (e.g. "replicate/hello-world")
                - A string in owner/name:version format (e.g. "replicate/hello-world:5c7d5dc6...")
                - A Model instance with owner and name attributes
                - A Version instance with an id attribute
                - A ModelVersionIdentifier dictionary with owner, name, and/or version keys
            use_file_output: If True (default), convert output URLs to AsyncFileOutput objects
            file_encoding_strategy: Strategy for encoding file inputs; options are "base64" or "url"
            wait: If True (default), block until the prediction completes. If False, return immediately.
                If an integer, wait up to that many seconds.
            **params: Additional parameters to pass to the prediction creation endpoint, including
                the required "input" dictionary of model-specific parameters

        Returns:
            The prediction output, which may be a basic type (str, int, etc.), an AsyncFileOutput object,
            a list of AsyncFileOutput objects, or a dictionary of AsyncFileOutput objects, depending on what
            the model returns.

        Raises:
            ModelError: If the model run fails
            ValueError: If the reference format is invalid
            TypeError: If both the wait and prefer parameters are provided
        """
        from .lib._predictions import async_run

        return await async_run(
            self,
            ref,
            wait=wait,
            use_file_output=use_file_output,
            file_encoding_strategy=file_encoding_strategy,
            **params,
        )

    def copy(
        self,
        *,
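
A usage sketch of the accepted reference forms (model names and the version id are taken from the docstring above; treat the outputs as illustrative):

from replicate import Replicate

client = Replicate()

# owner/name form
out = client.run("replicate/hello-world", input={"text": "hello"})

# owner/name:version form
out = client.run(
    "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
    input={"text": "hello"},
)

# wait accepts a bool or a number of seconds; polling cadence is tunable
# via the REPLICATE_POLL_INTERVAL environment variable (default 0.5s).
out = client.run("replicate/hello-world", input={"text": "hello"}, wait=60)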
src/replicate/_exceptions.py · 13 additions, 0 deletions
@@ -6,6 +6,8 @@

import httpx

from replicate.types.prediction import Prediction

__all__ = [
    "BadRequestError",
    "AuthenticationError",
@@ -15,6 +17,7 @@
    "UnprocessableEntityError",
    "RateLimitError",
    "InternalServerError",
    "ModelError",
]


@@ -106,3 +109,13 @@ class RateLimitError(APIStatusError):

class InternalServerError(APIStatusError):
    pass


class ModelError(ReplicateError):
    """An error raised from the model's own code during a run."""

    prediction: Prediction

    def __init__(self, prediction: Prediction) -> None:
        self.prediction = prediction
        super().__init__(prediction.error)
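
A hedged sketch of catching the new exception; the import path assumes ModelError lives only in _exceptions, since this diff does not show a top-level re-export:

import replicate
from replicate._exceptions import ModelError

try:
    replicate.run("replicate/hello-world", input={"text": "hi"})
except ModelError as err:
    # The failed prediction rides along on the exception; its error field
    # doubles as the exception message via super().__init__(prediction.error).
    print(err.prediction.error)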
src/replicate/_module_client.py · 15 additions, 1 deletion
@@ -3,7 +3,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING
-from typing_extensions import override
+from typing_extensions import cast, override

if TYPE_CHECKING:
    from .resources.files import FilesResource
@@ -74,6 +74,20 @@ def __load__(self) -> PredictionsResource:
        return _load_client().predictions


if TYPE_CHECKING:
    from ._client import Replicate

    # get the type checker to infer the run symbol to the same type
    # as the method on the client so we don't have to define it twice
    __client: Replicate = cast(Replicate, {})
    run = __client.run
else:

    def _run(*args, **kwargs):
        return _load_client().run(*args, **kwargs)

    run = _run

files: FilesResource = FilesResourceProxy().__as_proxied__()
models: ModelsResource = ModelsResourceProxy().__as_proxied__()
account: AccountResource = AccountResourceProxy().__as_proxied__()
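
Thanks to the proxy, the module-level symbol and the client method share one signature; a sketch of the equivalence (assuming REPLICATE_API_TOKEN is set in the environment):

import replicate
from replicate import Replicate

# Module-level call: forwards to a lazily constructed default client.
out_a = replicate.run("replicate/hello-world", input={"text": "hi"})

# Explicit client call: same signature and behavior.
out_b = Replicate().run("replicate/hello-world", input={"text": "hi"})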