
Commit 4edf967

Merge pull request #18 from replicate/release-please--branches--main--changes--next
release: 0.3.0
2 parents dadf411 + f9c8204 commit 4edf967

25 files changed (+1776 −49 lines)

.release-please-manifest.json

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 {
-  ".": "0.2.1"
+  ".": "0.3.0"
 }

.stats.yml

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
 configured_endpoints: 35
-openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-efbc8cc2d74644b213e161d3e11e0589d1cef181fb318ea02c8eb6b00f245713.yml
-openapi_spec_hash: 13da0c06c900b61cd98ab678e024987a
+openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-88cf5fe1f5accb56ae9fbb31c0df00d1552762d4c558d16d8547894ae95e8ccb.yml
+openapi_spec_hash: 43283d20f335a04241cce165452ff50e
 config_hash: 84794ed69d841684ff08a8aa889ef103

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
@@ -1,5 +1,13 @@
 # Changelog
 
+## 0.3.0 (2025-05-08)
+
+Full Changelog: [v0.2.1...v0.3.0](https://github.com/replicate/replicate-python-stainless/compare/v0.2.1...v0.3.0)
+
+### Features
+
+* **api:** api update ([0e4a103](https://github.com/replicate/replicate-python-stainless/commit/0e4a10391bebf0cae929c8d11ccd7415d1785500))
+
 ## 0.2.1 (2025-05-07)
 
 Full Changelog: [v0.2.0...v0.2.1](https://github.com/replicate/replicate-python-stainless/compare/v0.2.0...v0.2.1)

api.md

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ Methods:
 
 - <code title="post /predictions">client.predictions.<a href="./src/replicate/resources/predictions.py">create</a>(\*\*<a href="src/replicate/types/prediction_create_params.py">params</a>) -> <a href="./src/replicate/types/prediction.py">Prediction</a></code>
 - <code title="get /predictions">client.predictions.<a href="./src/replicate/resources/predictions.py">list</a>(\*\*<a href="src/replicate/types/prediction_list_params.py">params</a>) -> <a href="./src/replicate/types/prediction.py">SyncCursorURLPageWithCreatedFilters[Prediction]</a></code>
-- <code title="post /predictions/{prediction_id}/cancel">client.predictions.<a href="./src/replicate/resources/predictions.py">cancel</a>(\*, prediction_id) -> None</code>
+- <code title="post /predictions/{prediction_id}/cancel">client.predictions.<a href="./src/replicate/resources/predictions.py">cancel</a>(\*, prediction_id) -> <a href="./src/replicate/types/prediction.py">Prediction</a></code>
 - <code title="get /predictions/{prediction_id}">client.predictions.<a href="./src/replicate/resources/predictions.py">get</a>(\*, prediction_id) -> <a href="./src/replicate/types/prediction.py">Prediction</a></code>
 
 # Trainings
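
An illustrative sketch (not part of this commit) of what the changed cancel signature enables: the call now returns the updated Prediction rather than None, so its fields can be inspected directly. It assumes the synchronous client class is exported as Replicate (mirroring AsyncReplicate used in the examples) and that credentials are already configured; the prediction ID and the status field access are placeholders.

import replicate

client = replicate.Replicate()  # assumes API credentials are configured for this environment

# cancel() now returns the updated Prediction instead of None.
prediction = client.predictions.cancel(prediction_id="p123")  # "p123" is a placeholder ID
print(prediction.status)  # assumes the Prediction model exposes a status field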

examples/run_a_model.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+import rich
+
+import replicate
+
+outputs = replicate.run(
+    "black-forest-labs/flux-schnell",
+    input={"prompt": "astronaut riding a rocket like a horse"},
+)
+rich.print(outputs)
+for index, output in enumerate(outputs):
+    with open(f"output_{index}.webp", "wb") as file:
+        file.write(output.read())

examples/run_async.py

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+import asyncio
+
+from replicate import AsyncReplicate
+
+replicate = AsyncReplicate()
+
+# https://replicate.com/stability-ai/sdxl
+model_version = "stability-ai/sdxl:39ed52f2a78e934b3ba6e2a89f5b1c712de7dfea535525255b1aa35c5565e08b"
+prompts = [f"A chariot pulled by a team of {count} rainbow unicorns" for count in ["two", "four", "six", "eight"]]
+
+
+async def main() -> None:
+    # Create tasks with asyncio.gather directly
+    tasks = [replicate.run(model_version, input={"prompt": prompt}) for prompt in prompts]
+
+    results = await asyncio.gather(*tasks)
+    print(results)
+
+
+asyncio.run(main())

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "replicate-stainless"
-version = "0.2.1"
+version = "0.3.0"
 description = "The official Python library for the replicate API"
 dynamic = ["readme"]
 license = "Apache-2.0"

src/replicate/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -22,6 +22,7 @@
 from ._version import __title__, __version__
 from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse
 from ._constants import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, DEFAULT_CONNECTION_LIMITS
+from .lib._files import FileOutput as FileOutput, AsyncFileOutput as AsyncFileOutput
 from ._exceptions import (
     APIError,
     ConflictError,
@@ -38,6 +39,7 @@
     UnprocessableEntityError,
     APIResponseValidationError,
 )
+from .lib._models import Model as Model, Version as Version, ModelVersionIdentifier as ModelVersionIdentifier
 from ._base_client import DefaultHttpxClient, DefaultAsyncHttpxClient
 from ._utils._logs import setup_logging as _setup_logging
 
@@ -80,6 +82,11 @@
     "DEFAULT_CONNECTION_LIMITS",
     "DefaultHttpxClient",
     "DefaultAsyncHttpxClient",
+    "FileOutput",
+    "AsyncFileOutput",
+    "Model",
+    "Version",
+    "ModelVersionIdentifier",
 ]
 
 _setup_logging()
@@ -230,6 +237,7 @@ def _reset_client() -> None: # type: ignore[reportUnusedFunction]
 
 
 from ._module_client import (
+    run as run,
     files as files,
     models as models,
     account as account,
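
With these re-exports, the helper types can be imported from the package root. A minimal sketch (not part of this commit) of annotation-style usage; the alias ModelRef and the save_outputs helper are invented names for illustration:

from typing import List, Union

from replicate import FileOutput, Model, ModelVersionIdentifier, Version

# A union mirroring what client.run() accepts as its ref argument.
ModelRef = Union[Model, Version, ModelVersionIdentifier, str]


def save_outputs(outputs: List[FileOutput], prefix: str = "output") -> None:
    # FileOutput.read() returns the file bytes, as used in examples/run_a_model.py.
    for index, output in enumerate(outputs):
        with open(f"{prefix}_{index}.bin", "wb") as file:
            file.write(output.read())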

src/replicate/_client.py

Lines changed: 110 additions & 2 deletions
@@ -3,11 +3,15 @@
 from __future__ import annotations
 
 import os
-from typing import TYPE_CHECKING, Any, Union, Mapping
-from typing_extensions import Self, override
+from typing import TYPE_CHECKING, Any, Union, Mapping, Optional
+from typing_extensions import Self, Unpack, override
 
 import httpx
 
+from replicate.lib._files import FileEncodingStrategy
+from replicate.lib._predictions import Model, Version, ModelVersionIdentifier
+from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
+
 from . import _exceptions
 from ._qs import Querystring
 from ._types import (
@@ -171,6 +175,10 @@ def with_raw_response(self) -> ReplicateWithRawResponse:
     def with_streaming_response(self) -> ReplicateWithStreamedResponse:
         return ReplicateWithStreamedResponse(self)
 
+    @cached_property
+    def poll_interval(self) -> float:
+        return float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
+
     @property
     @override
     def qs(self) -> Querystring:
@@ -191,6 +199,54 @@ def default_headers(self) -> dict[str, str | Omit]:
             **self._custom_headers,
         }
 
+    def run(
+        self,
+        ref: Union[Model, Version, ModelVersionIdentifier, str],
+        *,
+        file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
+        use_file_output: bool = True,
+        wait: Union[int, bool, NotGiven] = NOT_GIVEN,
+        **params: Unpack[PredictionCreateParamsWithoutVersion],
+    ) -> Any:
+        """
+        Run a model prediction.
+
+        Args:
+          ref: Reference to the model or version to run. Can be:
+            - A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
+            - A string with owner/name format (e.g. "replicate/hello-world")
+            - A string with owner/name:version format (e.g. "replicate/hello-world:5c7d5dc6...")
+            - A Model instance with owner and name attributes
+            - A Version instance with id attribute
+            - A ModelVersionIdentifier dictionary with owner, name, and/or version keys
+          file_encoding_strategy: Strategy for encoding file inputs, options are "base64" or "url"
+          use_file_output: If True (default), convert output URLs to FileOutput objects
+          wait: If True (default), wait for the prediction to complete. If False, return immediately.
+            If an integer, wait up to that many seconds.
+          **params: Additional parameters to pass to the prediction creation endpoint including
+            the required "input" dictionary with model-specific parameters
+
+        Returns:
+          The prediction output, which could be a basic type (str, int, etc.), a FileOutput object,
+          a list of FileOutput objects, or a dictionary of FileOutput objects, depending on what
+          the model returns.
+
+        Raises:
+          ModelError: If the model run fails
+          ValueError: If the reference format is invalid
+          TypeError: If both wait and prefer parameters are provided
+        """
+        from .lib._predictions import run
+
+        return run(
+            self,
+            ref,
+            wait=wait,
+            use_file_output=use_file_output,
+            file_encoding_strategy=file_encoding_strategy,
+            **params,
+        )
+
     def copy(
         self,
         *,
@@ -393,6 +449,10 @@ def with_raw_response(self) -> AsyncReplicateWithRawResponse:
     def with_streaming_response(self) -> AsyncReplicateWithStreamedResponse:
         return AsyncReplicateWithStreamedResponse(self)
 
+    @cached_property
+    def poll_interval(self) -> float:
+        return float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
+
     @property
     @override
     def qs(self) -> Querystring:
@@ -413,6 +473,54 @@ def default_headers(self) -> dict[str, str | Omit]:
             **self._custom_headers,
         }
 
+    async def run(
+        self,
+        ref: Union[Model, Version, ModelVersionIdentifier, str],
+        *,
+        use_file_output: bool = True,
+        file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
+        wait: Union[int, bool, NotGiven] = NOT_GIVEN,
+        **params: Unpack[PredictionCreateParamsWithoutVersion],
+    ) -> Any:
+        """
+        Run a model prediction asynchronously.
+
+        Args:
+          ref: Reference to the model or version to run. Can be:
+            - A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
+            - A string with owner/name format (e.g. "replicate/hello-world")
+            - A string with owner/name:version format (e.g. "replicate/hello-world:5c7d5dc6...")
+            - A Model instance with owner and name attributes
+            - A Version instance with id attribute
+            - A ModelVersionIdentifier dictionary with owner, name, and/or version keys
+          use_file_output: If True (default), convert output URLs to AsyncFileOutput objects
+          file_encoding_strategy: Strategy for encoding file inputs, options are "base64" or "url"
+          wait: If True (default), wait for the prediction to complete. If False, return immediately.
+            If an integer, wait up to that many seconds.
+          **params: Additional parameters to pass to the prediction creation endpoint including
+            the required "input" dictionary with model-specific parameters
+
+        Returns:
+          The prediction output, which could be a basic type (str, int, etc.), an AsyncFileOutput object,
+          a list of AsyncFileOutput objects, or a dictionary of AsyncFileOutput objects, depending on what
+          the model returns.
+
+        Raises:
+          ModelError: If the model run fails
+          ValueError: If the reference format is invalid
+          TypeError: If both wait and prefer parameters are provided
+        """
+        from .lib._predictions import async_run
+
+        return await async_run(
+            self,
+            ref,
+            wait=wait,
+            use_file_output=use_file_output,
+            file_encoding_strategy=file_encoding_strategy,
+            **params,
+        )
+
     def copy(
         self,
         *,
src/replicate/_exceptions.py

Lines changed: 13 additions & 0 deletions
@@ -6,6 +6,8 @@
 
 import httpx
 
+from replicate.types.prediction import Prediction
+
 __all__ = [
     "BadRequestError",
     "AuthenticationError",
@@ -15,6 +17,7 @@
     "UnprocessableEntityError",
     "RateLimitError",
     "InternalServerError",
+    "ModelError",
 ]
 
 
@@ -106,3 +109,13 @@ class RateLimitError(APIStatusError):
 
 class InternalServerError(APIStatusError):
     pass
+
+
+class ModelError(ReplicateError):
+    """An error from user's code in a model."""
+
+    prediction: Prediction
+
+    def __init__(self, prediction: Prediction) -> None:
+        self.prediction = prediction
+        super().__init__(prediction.error)
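
A minimal sketch (not part of this commit) of handling the new ModelError raised by a failed run. The model reference and input are placeholders taken from the run() docstring, and the import path reflects where this diff defines the class (a root-level re-export is not shown in this changeset):

import replicate
from replicate._exceptions import ModelError

try:
    output = replicate.run(
        "replicate/hello-world",  # placeholder model reference from the docstring
        input={"text": "hello"},
    )
    print(output)
except ModelError as error:
    # ModelError carries the failed Prediction; its error field is what the
    # exception message is built from (see __init__ above).
    print("model failed:", error.prediction.error)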
