Skip to content
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
741df3a
feat(client): first pass at .run() helper
dgellow Apr 17, 2025
cfcf1a7
fix linter issues
dtmeadows Apr 17, 2025
36a65e6
Update run() return type to Any
dgellow Apr 23, 2025
38cdddc
Fix formatting
dgellow Apr 23, 2025
ffdaeda
Fix async run()
dgellow Apr 23, 2025
c6f76bb
Add run() tests
dgellow Apr 23, 2025
4a25c43
Add run() tests
dgellow Apr 23, 2025
85d1942
Merge remote-tracking branch 'origin/next' into sam/run-helper
dtmeadows May 5, 2025
00723b4
add back separate models
dtmeadows May 5, 2025
46b1f8d
clean up tests
dtmeadows May 5, 2025
b969e71
add support for more ref types
dtmeadows May 5, 2025
7b1a1cc
forgot the file
dtmeadows May 5, 2025
c9f34c5
add support for use_file_output
dtmeadows May 5, 2025
1804f7c
temp fix to generated type
dtmeadows May 5, 2025
c80c7f3
fix up tests
dtmeadows May 5, 2025
82fe905
fix types and skip 2 remaining tests
dtmeadows May 5, 2025
eaf0a07
implement encode_json for input and thread through
dtmeadows May 5, 2025
6824251
adjust typing to just be Any for PredictionOutput
dtmeadows May 6, 2025
eadfb53
Merge remote-tracking branch 'origin/next' into sam/run-helper
dtmeadows May 6, 2025
cf383c6
finish last bits of todos
dtmeadows May 6, 2025
24d30f4
updates to support async example
dtmeadows May 6, 2025
11e09d3
clean up testing a bit more
dtmeadows May 6, 2025
0257edf
move docs up into client
dtmeadows May 6, 2025
00970c0
Merge remote-tracking branch 'origin/next' into sam/run-helper
dtmeadows May 7, 2025
c7216db
clean up helpers to match underlying api changes
dtmeadows May 7, 2025
80b28fc
clean up and use better import
dtmeadows May 7, 2025
a5c02d4
fixup!
dtmeadows May 7, 2025
aef2230
Merge remote-tracking branch 'origin/next' into sam/run-helper
dtmeadows May 7, 2025
0ec2897
fix files uploading
dtmeadows May 7, 2025
2143731
add test for file upload
dtmeadows May 7, 2025
0beb9ef
Revert "add test for file upload"
dtmeadows May 7, 2025
c9c4b8c
Revert "fix files uploading"
dtmeadows May 7, 2025
ad4da10
fixup!
dtmeadows May 7, 2025
e8af3f1
Merge remote-tracking branch 'origin/next' into sam/run-helper
dtmeadows May 8, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions examples/run_a_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import rich

from replicate import Replicate

# A single client instance is enough for the whole script.
replicate = Replicate()

# Run the model and block until the prediction has finished.
images = replicate.run(
    "black-forest-labs/flux-schnell",
    input={"prompt": "astronaut riding a rocket like a horse"},
)
rich.print(images)

# Persist every generated image to the working directory.
for position, image in enumerate(images):
    destination = f"output_{position}.webp"
    with open(destination, "wb") as image_file:
        image_file.write(image.read())
20 changes: 20 additions & 0 deletions examples/run_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import asyncio

from replicate import AsyncReplicate

client = AsyncReplicate()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rename client to replicate here? Will that work?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

(we can't change the from replicate import AsyncReplicate unfortunately though here, since the module client only returns the sync client)


# https://replicate.com/stability-ai/sdxl
model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
prompts = [f"A chariot pulled by a team of {count} rainbow unicorns" for count in ["two", "four", "six", "eight"]]


async def main() -> None:
    """Start one prediction per prompt concurrently and print all results."""
    pending = [client.run(model_version, input={"prompt": prompt}) for prompt in prompts]
    # gather awaits every prediction in parallel and preserves input order.
    results = await asyncio.gather(*pending)
    print(results)


asyncio.run(main())
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ dependencies = [
"anyio>=3.5.0, <5",
"distro>=1.7.0, <2",
"sniffio",
# NOTE: "asyncio" must not be listed here — it is part of the Python
# standard library, and the PyPI package of that name is an obsolete
# Python 3.3-era backport.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this isn't quite right, asyncio is stdlib

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed!

]
requires-python = ">= 3.8"
classifiers = [
Expand Down
8 changes: 8 additions & 0 deletions src/replicate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ._version import __title__, __version__
from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
from ._constants import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, DEFAULT_CONNECTION_LIMITS
from .lib._files import FileOutput as FileOutput, AsyncFileOutput as AsyncFileOutput
from ._exceptions import (
APIError,
ConflictError,
Expand All @@ -38,6 +39,7 @@
UnprocessableEntityError,
APIResponseValidationError,
)
from .lib._models import Model as Model, Version as Version, ModelVersionIdentifier as ModelVersionIdentifier
from ._base_client import DefaultHttpxClient, DefaultAsyncHttpxClient
from ._utils._logs import setup_logging as _setup_logging

Expand Down Expand Up @@ -80,6 +82,11 @@
"DEFAULT_CONNECTION_LIMITS",
"DefaultHttpxClient",
"DefaultAsyncHttpxClient",
"FileOutput",
"AsyncFileOutput",
"Model",
"Version",
"ModelVersionIdentifier",
]

_setup_logging()
Expand Down Expand Up @@ -230,6 +237,7 @@ def _reset_client() -> None: # type: ignore[reportUnusedFunction]


from ._module_client import (
run as run,
models as models,
account as account,
hardware as hardware,
Expand Down
58 changes: 56 additions & 2 deletions src/replicate/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, Union, Mapping
from typing_extensions import Self, override
from typing import TYPE_CHECKING, Any, Union, Mapping, Optional
from typing_extensions import Self, Unpack, override

import httpx

from replicate.lib._files import FileEncodingStrategy
from replicate.lib._predictions import Model, Version, ModelVersionIdentifier
from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion

from . import _exceptions
from ._qs import Querystring
from ._types import (
Expand Down Expand Up @@ -164,6 +168,10 @@ def with_raw_response(self) -> ReplicateWithRawResponse:
def with_streaming_response(self) -> ReplicateWithStreamedResponse:
return ReplicateWithStreamedResponse(self)

@cached_property
def poll_interval(self) -> float:
    """Seconds to wait between prediction status polls.

    Read once from the ``REPLICATE_POLL_INTERVAL`` environment variable;
    defaults to 0.5 seconds when the variable is unset.
    """
    raw_interval = os.environ.get("REPLICATE_POLL_INTERVAL", "0.5")
    return float(raw_interval)

@property
@override
def qs(self) -> Querystring:
Expand All @@ -184,6 +192,27 @@ def default_headers(self) -> dict[str, str | Omit]:
**self._custom_headers,
}

def run(
    self,
    ref: Union[Model, Version, ModelVersionIdentifier, str],
    *,
    file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
    use_file_output: bool = True,
    wait: Union[int, bool, NotGiven] = NOT_GIVEN,
    **params: Unpack[PredictionCreateParamsWithoutVersion],
) -> Any:
    """Run a model and wait for its output.

    Args:
        ref: What to run — a ``Model``, ``Version``,
            ``ModelVersionIdentifier``, or a string reference.
        file_encoding_strategy: Optional strategy used when encoding
            file inputs; forwarded unchanged to the helper.
        use_file_output: Forwarded to the helper; presumably controls
            whether file results are wrapped as file-output objects.
        wait: Forwarded to the helper — see
            ``replicate.lib._predictions.run`` for its semantics.
        **params: Remaining prediction-creation parameters.
    """
    # Deferred import — presumably avoids a circular import at module load.
    from .lib._predictions import run as _run_prediction

    return _run_prediction(
        self,
        ref,
        wait=wait,
        file_encoding_strategy=file_encoding_strategy,
        use_file_output=use_file_output,
        **params,
    )

def copy(
self,
*,
Expand Down Expand Up @@ -380,6 +409,10 @@ def with_raw_response(self) -> AsyncReplicateWithRawResponse:
def with_streaming_response(self) -> AsyncReplicateWithStreamedResponse:
return AsyncReplicateWithStreamedResponse(self)

@cached_property
def poll_interval(self) -> float:
    """Polling cadence in seconds for prediction status checks.

    Configurable via the ``REPLICATE_POLL_INTERVAL`` environment
    variable; falls back to 0.5 seconds when it is not set.
    """
    configured = os.environ.get("REPLICATE_POLL_INTERVAL")
    return float(configured if configured is not None else "0.5")

@property
@override
def qs(self) -> Querystring:
Expand All @@ -400,6 +433,27 @@ def default_headers(self) -> dict[str, str | Omit]:
**self._custom_headers,
}

async def run(
    self,
    ref: Union[Model, Version, ModelVersionIdentifier, str],
    *,
    use_file_output: bool = True,
    file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
    wait: Union[int, bool, NotGiven] = NOT_GIVEN,
    **params: Unpack[PredictionCreateParamsWithoutVersion],
) -> Any:
    """Run a model and wait for its output.

    Args:
        ref: What to run — a ``Model``, ``Version``,
            ``ModelVersionIdentifier``, or a string reference.
        use_file_output: Forwarded to the helper; presumably controls
            whether file results are wrapped as async file-output objects.
        file_encoding_strategy: Optional strategy used when encoding
            file inputs; forwarded unchanged.
        wait: Forwarded to the helper — see
            ``replicate.lib._predictions.async_run`` for its semantics.
        **params: Remaining prediction-creation parameters.
    """
    # Deferred import — presumably avoids a circular import at module load.
    from .lib._predictions import async_run as _async_run

    prediction_output = await _async_run(
        self,
        ref,
        wait=wait,
        file_encoding_strategy=file_encoding_strategy,
        use_file_output=use_file_output,
        **params,
    )
    return prediction_output

def copy(
self,
*,
Expand Down
13 changes: 13 additions & 0 deletions src/replicate/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import httpx

from replicate.types.prediction import Prediction

__all__ = [
"BadRequestError",
"AuthenticationError",
Expand All @@ -15,6 +17,7 @@
"UnprocessableEntityError",
"RateLimitError",
"InternalServerError",
"ModelError",
]


Expand Down Expand Up @@ -106,3 +109,13 @@ class RateLimitError(APIStatusError):

class InternalServerError(APIStatusError):
pass


class ModelError(ReplicateError):
    """An error from user's code in a model.

    Wraps the failed ``Prediction`` so callers can inspect it; the
    prediction's ``error`` field is used as the exception message.
    """

    # The prediction whose error produced this exception.
    prediction: Prediction

    def __init__(self, prediction: Prediction) -> None:
        self.prediction = prediction
        super().__init__(prediction.error)
16 changes: 15 additions & 1 deletion src/replicate/_module_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing_extensions import override
from typing_extensions import cast, override

if TYPE_CHECKING:
from .resources.account import AccountResource
Expand Down Expand Up @@ -67,6 +67,20 @@ def __load__(self) -> PredictionsResource:
return _load_client().predictions


if TYPE_CHECKING:
    from ._client import Replicate

    # get the type checker to infer the run symbol to the same type
    # as the method on the client so we don't have to define it twice
    # (the cast of an empty dict is never executed — TYPE_CHECKING is
    # False at runtime, so this branch exists purely for type analysis)
    __client: Replicate = cast(Replicate, {})
    run = __client.run
else:

    # Runtime implementation: defer client construction until the first
    # call so importing the module stays side-effect free.
    def _run(*args, **kwargs):
        return _load_client().run(*args, **kwargs)

    run = _run

models: ModelsResource = ModelsResourceProxy().__as_proxied__()
account: AccountResource = AccountResourceProxy().__as_proxied__()
hardware: HardwareResource = HardwareResourceProxy().__as_proxied__()
Expand Down
Loading