
Commit 6ed3cf7

Add async support (#193)
This PR adds support for async operations to the Replicate client. Namespace operations like `predictions.list` and `models.create` now have async variants with the `async_` prefix (`predictions.async_list` and `models.async_create`). Here's an example of what that looks like in practice:

```python
import replicate

model = await replicate.models.async_get("stability-ai/sdxl")

input = {
    "prompt": "A chariot pulled by a team of rainbow unicorns, driven by an astronaut, dramatic lighting",
}

output = await replicate.async_run(f"stability-ai/sdxl:{model.latest_version.id}", input)
```

<details>
<summary>Output</summary>
<img src="https://github.com/replicate/replicate-python/assets/7659/6927f8b4-5f92-495d-a87c-135f31aa1847"/>
</details>

One of the most common questions I hear is how to run a bunch of predictions in parallel. The async functionality provided by this PR makes it really straightforward:

```python
import asyncio
import replicate

# https://replicate.com/stability-ai/sdxl
model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"

prompts = [
    f"A chariot pulled by a team of {count} rainbow unicorns"
    for count in ["two", "four", "six", "eight"]
]

async with asyncio.TaskGroup() as tg:
    tasks = [
        tg.create_task(replicate.async_run(model_version, input={"prompt": prompt}))
        for prompt in prompts
    ]

results = await asyncio.gather(*tasks)
print(results)
```

Under the hood, `Client` manages an `httpx.Client` and an `httpx.AsyncClient`, which handle calls to `_request` and `_async_request`, respectively. Both are created lazily, so API consumers using only sync or only async functionality won't be affected by functionality they aren't using.

Implementation-wise, sync and async variants have separate code paths. This creates nontrivial amounts of duplication, but its benefits to clarity and performance justify those costs. For instance, it'd have been nice if the sync variant were implemented as a blocking call to the async variant, but that would require starting an event loop, which has additional overhead and causes problems if done within an existing event loop.

Alternative to #76

Resolves #145
Resolves #107
Resolves #74

---------

Signed-off-by: Mattt Zmuda <[email protected]>
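To illustrate the lazy-initialization point in the "Under the hood" paragraph, here is a condensed sketch of the pattern, not the committed code: the real properties in `replicate/client.py` (see the diff below) delegate to a shared `_build_httpx_client` helper that also wires up headers, timeouts, and the retry transport.

```python
from typing import Optional

import httpx


class Client:
    """Sketch: each transport is created on first use, never eagerly."""

    __client: Optional[httpx.Client] = None
    __async_client: Optional[httpx.AsyncClient] = None

    @property
    def _client(self) -> httpx.Client:
        # Constructed only when a sync method first needs it.
        if self.__client is None:
            self.__client = httpx.Client(base_url="https://api.replicate.com")
        return self.__client

    @property
    def _async_client(self) -> httpx.AsyncClient:
        # Likewise deferred until the first async call.
        if self.__async_client is None:
            self.__async_client = httpx.AsyncClient(base_url="https://api.replicate.com")
        return self.__async_client
```

A sync-only consumer therefore never pays for an `httpx.AsyncClient`, and an async-only consumer never constructs the sync one.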
1 parent db51ee0 · commit 6ed3cf7

22 files changed: +9672 −553 lines changed

README.md

Lines changed: 26 additions & 0 deletions
````diff
@@ -53,6 +53,32 @@ Some models, like [methexis-inc/img2prompt](https://replicate.com/methexis-inc/i
     "an astronaut riding a horse"
 ```
 
+> [!NOTE]
+> You can also use the Replicate client asynchronously by prepending `async_` to the method name.
+>
+> Here's an example of how to run several predictions concurrently and wait for them all to complete:
+>
+> ```python
+> import asyncio
+> import replicate
+>
+> # https://replicate.com/stability-ai/sdxl
+> model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
+> prompts = [
+>     f"A chariot pulled by a team of {count} rainbow unicorns"
+>     for count in ["two", "four", "six", "eight"]
+> ]
+>
+> async with asyncio.TaskGroup() as tg:
+>     tasks = [
+>         tg.create_task(replicate.async_run(model_version, input={"prompt": prompt}))
+>         for prompt in prompts
+>     ]
+>
+> results = await asyncio.gather(*tasks)
+> print(results)
+> ```
+
 ## Run a model in the background
 
 You can start a model and run it in the background:
````
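One compatibility note on the README example: `asyncio.TaskGroup` is only available on Python 3.11 and later. On older interpreters, a roughly equivalent version can be written with `asyncio.gather` alone; the following is a sketch, not part of this diff.

```python
import asyncio

import replicate

# https://replicate.com/stability-ai/sdxl
model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"

prompts = [
    f"A chariot pulled by a team of {count} rainbow unicorns"
    for count in ["two", "four", "six", "eight"]
]


async def main() -> None:
    # gather schedules every coroutine concurrently and returns results in order.
    results = await asyncio.gather(
        *(replicate.async_run(model_version, input={"prompt": prompt}) for prompt in prompts)
    )
    print(results)


asyncio.run(main())
```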

pyproject.toml

Lines changed: 2 additions & 0 deletions
```diff
@@ -37,6 +37,7 @@ packages = ["replicate"]
 [tool.mypy]
 plugins = "pydantic.mypy"
 exclude = ["tests/"]
+enable_incomplete_feature = ["Unpack"]
 
 [tool.pylint.main]
 disable = [
@@ -48,6 +49,7 @@ disable = [
   "W0622", # Redefining built-in
   "R0903", # Too few public methods
 ]
+good-names = ["id"]
 
 [tool.ruff]
 select = [
```
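The `enable_incomplete_feature = ["Unpack"]` line is there because the new `run`/`async_run` signatures type their keyword arguments as `Unpack["Predictions.CreatePredictionParams"]`, and mypy still gated `Unpack` behind an opt-in flag at the time. Below is a minimal sketch of the pattern, with hypothetical field names standing in for the real `CreatePredictionParams` (which lives in `replicate/prediction.py` and is not shown in this commit excerpt).

```python
from typing import Any, Dict, List, Optional

from typing_extensions import TypedDict, Unpack


class CreatePredictionParams(TypedDict, total=False):
    # Hypothetical fields for illustration; the real definition may differ.
    webhook: str
    webhook_events_filter: List[str]
    stream: bool


def run(
    ref: str,
    input: Optional[Dict[str, Any]] = None,
    **params: Unpack[CreatePredictionParams],
) -> Any:
    # mypy now checks that callers only pass keys declared on the TypedDict,
    # which is what the pyproject.toml flag above enables.
    ...
```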

replicate/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -1,7 +1,10 @@
 from replicate.client import Client
 
 default_client = Client()
+
 run = default_client.run
+async_run = default_client.async_run
+
 collections = default_client.collections
 hardware = default_client.hardware
 deployments = default_client.deployments
```

replicate/client.py

Lines changed: 99 additions & 78 deletions
```diff
@@ -1,35 +1,37 @@
 import os
 import random
-import re
 import time
 from datetime import datetime
 from typing import (
     Any,
+    Dict,
     Iterable,
     Iterator,
     Mapping,
     Optional,
+    Type,
     Union,
 )
 
 import httpx
+from typing_extensions import Unpack
 
 from replicate.__about__ import __version__
 from replicate.collection import Collections
 from replicate.deployment import Deployments
-from replicate.exceptions import ModelError, ReplicateError
-from replicate.hardware import Hardwares
+from replicate.exceptions import ReplicateError
+from replicate.hardware import HardwareNamespace as Hardware
 from replicate.model import Models
 from replicate.prediction import Predictions
-from replicate.schema import make_schema_backwards_compatible
+from replicate.run import async_run, run
 from replicate.training import Trainings
-from replicate.version import Version
 
 
 class Client:
     """A Replicate API client library"""
 
     __client: Optional[httpx.Client] = None
+    __async_client: Optional[httpx.AsyncClient] = None
 
     def __init__(
         self,
@@ -42,46 +44,45 @@ def __init__(
         super().__init__()
 
         self._api_token = api_token
-        self._base_url = (
-            base_url
-            or os.environ.get("REPLICATE_API_BASE_URL")
-            or "https://api.replicate.com"
-        )
-        self._timeout = timeout or httpx.Timeout(
-            5.0, read=30.0, write=30.0, connect=5.0, pool=10.0
-        )
-        self._transport = kwargs.pop("transport", httpx.HTTPTransport())
+        self._base_url = base_url
+        self._timeout = timeout
         self._client_kwargs = kwargs
 
         self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
 
     @property
     def _client(self) -> httpx.Client:
-        if self.__client is None:
-            headers = {
-                "User-Agent": f"replicate-python/{__version__}",
-            }
-
-            api_token = self._api_token or os.environ.get("REPLICATE_API_TOKEN")
-
-            if api_token is not None and api_token != "":
-                headers["Authorization"] = f"Token {api_token}"
-
-            self.__client = httpx.Client(
+        if not self.__client:
+            self.__client = _build_httpx_client(
+                httpx.Client,
+                self._api_token,
+                self._base_url,
+                self._timeout,
                 **self._client_kwargs,
-                base_url=self._base_url,
-                headers=headers,
-                timeout=self._timeout,
-                transport=RetryTransport(wrapped_transport=self._transport),
-            )
+            )  # type: ignore[assignment]
+        return self.__client  # type: ignore[return-value]
 
-        return self.__client
+    @property
+    def _async_client(self) -> httpx.AsyncClient:
+        if not self.__async_client:
+            self.__async_client = _build_httpx_client(
+                httpx.AsyncClient,
+                self._api_token,
+                self._base_url,
+                self._timeout,
+                **self._client_kwargs,
+            )  # type: ignore[assignment]
+        return self.__async_client  # type: ignore[return-value]
 
     def _request(self, method: str, path: str, **kwargs) -> httpx.Response:
         resp = self._client.request(method, path, **kwargs)
+        _raise_for_status(resp)
+
+        return resp
 
-        if 400 <= resp.status_code < 600:
-            raise ReplicateError(resp.json()["detail"])
+    async def _async_request(self, method: str, path: str, **kwargs) -> httpx.Response:
+        resp = await self._async_client.request(method, path, **kwargs)
+        _raise_for_status(resp)
 
         return resp
 
@@ -100,11 +101,11 @@ def deployments(self) -> Deployments:
         return Deployments(client=self)
 
     @property
-    def hardware(self) -> Hardwares:
+    def hardware(self) -> Hardware:
         """
         Namespace for operations related to hardware.
         """
-        return Hardwares(client=self)
+        return Hardware(client=self)
 
     @property
     def models(self) -> Models:
@@ -127,55 +128,29 @@ def trainings(self) -> Trainings:
         """
         return Trainings(client=self)
 
-    def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]:  # noqa: ANN401
+    def run(
+        self,
+        ref: str,
+        input: Optional[Dict[str, Any]] = None,
+        **params: Unpack["Predictions.CreatePredictionParams"],
+    ) -> Union[Any, Iterator[Any]]:  # noqa: ANN401
         """
         Run a model and wait for its output.
-
-        Args:
-            model_version: The model version to run, in the format `owner/name:version`
-            kwargs: The input to the model, as a dictionary
-        Returns:
-            The output of the model
         """
-        # Split model_version into owner, name, version in format owner/name:version
-        match = re.match(
-            r"^(?P<owner>[^/]+)/(?P<name>[^:]+):(?P<version>.+)$", model_version
-        )
-        if not match:
-            raise ReplicateError(
-                f"Invalid model_version: {model_version}. Expected format: owner/name:version"
-            )
-
-        owner = match.group("owner")
-        name = match.group("name")
-        version_id = match.group("version")
 
-        prediction = self.predictions.create(version=version_id, **kwargs)
+        return run(self, ref, input, **params)
 
-        if owner and name:
-            # FIXME: There should be a method for fetching a version without first fetching its model
-            resp = self._request(
-                "GET", f"/v1/models/{owner}/{name}/versions/{version_id}"
-            )
-            version = Version(**resp.json())
-
-            # Return an iterator of the output
-            schema = make_schema_backwards_compatible(
-                version.openapi_schema, version.cog_version
-            )
-            output = schema["components"]["schemas"]["Output"]
-            if (
-                output.get("type") == "array"
-                and output.get("x-cog-array-type") == "iterator"
-            ):
-                return prediction.output_iterator()
-
-        prediction.wait()
-
-        if prediction.status == "failed":
-            raise ModelError(prediction.error)
+    async def async_run(
+        self,
+        ref: str,
+        input: Optional[Dict[str, Any]] = None,
+        **params: Unpack["Predictions.CreatePredictionParams"],
+    ) -> Union[Any, Iterator[Any]]:  # noqa: ANN401
+        """
+        Run a model and wait for its output asynchronously.
+        """
 
-        return prediction.output
+        return await async_run(self, ref, input, **params)
 
 
 # Adapted from https://github.com/encode/httpx/issues/108#issuecomment-1132753155
@@ -305,3 +280,49 @@ async def aclose(self) -> None:
 
     def close(self) -> None:
         self._wrapped_transport.close()  # type: ignore
+
+
+def _build_httpx_client(
+    client_type: Type[Union[httpx.Client, httpx.AsyncClient]],
+    api_token: Optional[str] = None,
+    base_url: Optional[str] = None,
+    timeout: Optional[httpx.Timeout] = None,
+    **kwargs,
+) -> Union[httpx.Client, httpx.AsyncClient]:
+    headers = {
+        "User-Agent": f"replicate-python/{__version__}",
+    }
+
+    if (
+        api_token := api_token or os.environ.get("REPLICATE_API_TOKEN")
+    ) and api_token != "":
+        headers["Authorization"] = f"Token {api_token}"
+
+    base_url = (
+        base_url or os.environ.get("REPLICATE_BASE_URL") or "https://api.replicate.com"
+    )
+    if base_url == "":
+        base_url = "https://api.replicate.com"
+
+    timeout = timeout or httpx.Timeout(
+        5.0, read=30.0, write=30.0, connect=5.0, pool=10.0
+    )
+
+    transport = kwargs.pop("transport", None) or (
+        httpx.HTTPTransport()
+        if client_type is httpx.Client
+        else httpx.AsyncHTTPTransport()
+    )
+
+    return client_type(
+        base_url=base_url,
+        headers=headers,
+        timeout=timeout,
+        transport=RetryTransport(wrapped_transport=transport),  # type: ignore[arg-type]
+        **kwargs,
+    )
+
+
+def _raise_for_status(resp: httpx.Response) -> None:
+    if 400 <= resp.status_code < 600:
+        raise ReplicateError(resp.json()["detail"])
```
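Putting the pieces together, the sync and async entry points now share the same `(ref, input, **params)` shape. Below is a small usage sketch, assuming `REPLICATE_API_TOKEN` is set in the environment and reusing the SDXL version pinned in the README example above.

```python
import asyncio

import replicate

# Version ref taken from the README example in this commit.
ref = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"

# Synchronous path: Client.run delegates to replicate.run.run and lazily
# builds an httpx.Client on first use.
output = replicate.run(ref, input={"prompt": "an astronaut riding a horse"})
print(output)


async def main() -> None:
    # Asynchronous path: identical signature, backed by httpx.AsyncClient.
    result = await replicate.async_run(ref, input={"prompt": "an astronaut riding a horse"})
    print(result)


asyncio.run(main())
```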
