Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .release-please-manifest.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
".": "0.2.0"
".": "0.2.1"
}
2 changes: 1 addition & 1 deletion .stats.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
configured_endpoints: 35
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-efbc8cc2d74644b213e161d3e11e0589d1cef181fb318ea02c8eb6b00f245713.yml
openapi_spec_hash: 13da0c06c900b61cd98ab678e024987a
config_hash: 8ef6787524fd12bfeb27f8c6acef3dca
config_hash: 84794ed69d841684ff08a8aa889ef103
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Changelog

## 0.2.1 (2025-05-07)

Full Changelog: [v0.2.0...v0.2.1](https://github.com/replicate/replicate-python-stainless/compare/v0.2.0...v0.2.1)

### Documentation

* update example requests ([eb0ba44](https://github.com/replicate/replicate-python-stainless/commit/eb0ba44af5b5006e758c9d9e65312f88b52dc3f5))

## 0.2.0 (2025-05-07)

Full Changelog: [v0.1.0...v0.2.0](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0...v0.2.0)
Expand Down
69 changes: 45 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ client = Replicate(
bearer_token=os.environ.get("REPLICATE_API_TOKEN"), # This is the default and can be omitted
)

account = client.account.get()
print(account.type)
prediction = client.predictions.get(
prediction_id="gm3qorzdhgbfurvjtvhg6dckhu",
)
print(prediction.id)
```

While you can provide a `bearer_token` keyword argument,
Expand All @@ -55,8 +57,10 @@ client = AsyncReplicate(


async def main() -> None:
account = await client.account.get()
print(account.type)
prediction = await client.predictions.get(
prediction_id="gm3qorzdhgbfurvjtvhg6dckhu",
)
print(prediction.id)


asyncio.run(main())
Expand Down Expand Up @@ -84,12 +88,12 @@ from replicate import Replicate

client = Replicate()

all_predictions = []
all_models = []
# Automatically fetches more pages as needed.
for prediction in client.predictions.list():
# Do something with prediction here
all_predictions.append(prediction)
print(all_predictions)
for model in client.models.list():
# Do something with model here
all_models.append(model)
print(all_models)
```

Or, asynchronously:
Expand All @@ -102,11 +106,11 @@ client = AsyncReplicate()


async def main() -> None:
all_predictions = []
all_models = []
# Iterate through items across all pages, issuing requests as needed.
async for prediction in client.predictions.list():
all_predictions.append(prediction)
print(all_predictions)
async for model in client.models.list():
all_models.append(model)
print(all_models)


asyncio.run(main())
Expand All @@ -115,7 +119,7 @@ asyncio.run(main())
Alternatively, you can use the `.has_next_page()`, `.next_page_info()`, or `.get_next_page()` methods for more granular control working with pages:

```python
first_page = await client.predictions.list()
first_page = await client.models.list()
if first_page.has_next_page():
print(f"will fetch next page using these details: {first_page.next_page_info()}")
next_page = await first_page.get_next_page()
Expand All @@ -127,11 +131,11 @@ if first_page.has_next_page():
Or just work directly with the returned data:

```python
first_page = await client.predictions.list()
first_page = await client.models.list()

print(f"next URL: {first_page.next}") # => "next URL: ..."
for prediction in first_page.results:
print(prediction.id)
for model in first_page.results:
print(model.cover_image_url)

# Remove `await` for non-async usage.
```
Expand Down Expand Up @@ -170,7 +174,10 @@ from replicate import Replicate
client = Replicate()

try:
client.account.get()
client.predictions.create(
input={"text": "Alice"},
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
)
except replicate.APIConnectionError as e:
print("The server could not be reached")
print(e.__cause__) # an underlying Exception, likely raised within httpx.
Expand Down Expand Up @@ -213,7 +220,10 @@ client = Replicate(
)

# Or, configure per-request:
client.with_options(max_retries=5).account.get()
client.with_options(max_retries=5).predictions.create(
input={"text": "Alice"},
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
)
```

### Timeouts
Expand All @@ -236,7 +246,10 @@ client = Replicate(
)

# Override per-request:
client.with_options(timeout=5.0).account.get()
client.with_options(timeout=5.0).predictions.create(
input={"text": "Alice"},
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
)
```

On timeout, an `APITimeoutError` is thrown.
Expand Down Expand Up @@ -277,11 +290,16 @@ The "raw" Response object can be accessed by prefixing `.with_raw_response.` to
from replicate import Replicate

client = Replicate()
response = client.account.with_raw_response.get()
response = client.predictions.with_raw_response.create(
input={
"text": "Alice"
},
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
)
print(response.headers.get('X-My-Header'))

account = response.parse() # get the object that `account.get()` would have returned
print(account.type)
prediction = response.parse() # get the object that `predictions.create()` would have returned
print(prediction.id)
```

These methods return an [`APIResponse`](https://github.com/replicate/replicate-python-stainless/tree/main/src/replicate/_response.py) object.
Expand All @@ -295,7 +313,10 @@ The above interface eagerly reads the full response body when you make the reque
To stream the response body, use `.with_streaming_response` instead, which requires a context manager and only reads the response body once you call `.read()`, `.text()`, `.json()`, `.iter_bytes()`, `.iter_text()`, `.iter_lines()` or `.parse()`. In the async client, these are async methods.

```python
with client.account.with_streaming_response.get() as response:
with client.predictions.with_streaming_response.create(
input={"text": "Alice"},
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
) as response:
print(response.headers.get("X-My-Header"))

for line in response.iter_lines():
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "replicate-stainless"
version = "0.2.0"
version = "0.2.1"
description = "The official Python library for the replicate API"
dynamic = ["readme"]
license = "Apache-2.0"
Expand Down
2 changes: 1 addition & 1 deletion src/replicate/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

__title__ = "replicate"
__version__ = "0.2.0" # x-release-please-version
__version__ = "0.2.1" # x-release-please-version
106 changes: 84 additions & 22 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from replicate import Replicate, AsyncReplicate, APIResponseValidationError
from replicate._types import Omit
from replicate._utils import maybe_transform
from replicate._models import BaseModel, FinalRequestOptions
from replicate._constants import RAW_RESPONSE_HEADER
from replicate._exceptions import APIStatusError, ReplicateError, APITimeoutError, APIResponseValidationError
Expand All @@ -32,6 +33,7 @@
BaseClient,
make_request_options,
)
from replicate.types.prediction_create_params import PredictionCreateParams

from .utils import update_env

Expand Down Expand Up @@ -740,20 +742,48 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
@mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
respx_mock.get("/account").mock(side_effect=httpx.TimeoutException("Test timeout error"))
respx_mock.post("/predictions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

with pytest.raises(APITimeoutError):
self.client.get("/account", cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}})
self.client.post(
"/predictions",
body=cast(
object,
maybe_transform(
dict(
input={"text": "Alice"},
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
),
PredictionCreateParams,
),
),
cast_to=httpx.Response,
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
)

assert _get_open_connections(self.client) == 0

@mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
respx_mock.get("/account").mock(return_value=httpx.Response(500))
respx_mock.post("/predictions").mock(return_value=httpx.Response(500))

with pytest.raises(APIStatusError):
self.client.get("/account", cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}})
self.client.post(
"/predictions",
body=cast(
object,
maybe_transform(
dict(
input={"text": "Alice"},
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
),
PredictionCreateParams,
),
),
cast_to=httpx.Response,
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
)

assert _get_open_connections(self.client) == 0

Expand Down Expand Up @@ -781,9 +811,9 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
return httpx.Response(500)
return httpx.Response(200)

respx_mock.get("/account").mock(side_effect=retry_handler)
respx_mock.post("/predictions").mock(side_effect=retry_handler)

response = client.account.with_raw_response.get()
response = client.predictions.with_raw_response.create(input={}, version="version")

assert response.retries_taken == failures_before_success
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
Expand All @@ -805,9 +835,11 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
return httpx.Response(500)
return httpx.Response(200)

respx_mock.get("/account").mock(side_effect=retry_handler)
respx_mock.post("/predictions").mock(side_effect=retry_handler)

response = client.account.with_raw_response.get(extra_headers={"x-stainless-retry-count": Omit()})
response = client.predictions.with_raw_response.create(
input={}, version="version", extra_headers={"x-stainless-retry-count": Omit()}
)

assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

Expand All @@ -828,9 +860,11 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
return httpx.Response(500)
return httpx.Response(200)

respx_mock.get("/account").mock(side_effect=retry_handler)
respx_mock.post("/predictions").mock(side_effect=retry_handler)

response = client.account.with_raw_response.get(extra_headers={"x-stainless-retry-count": "42"})
response = client.predictions.with_raw_response.create(
input={}, version="version", extra_headers={"x-stainless-retry-count": "42"}
)

assert response.http_request.headers.get("x-stainless-retry-count") == "42"

Expand Down Expand Up @@ -1524,23 +1558,47 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte
@mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
respx_mock.get("/account").mock(side_effect=httpx.TimeoutException("Test timeout error"))
respx_mock.post("/predictions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

with pytest.raises(APITimeoutError):
await self.client.get(
"/account", cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}
await self.client.post(
"/predictions",
body=cast(
object,
maybe_transform(
dict(
input={"text": "Alice"},
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
),
PredictionCreateParams,
),
),
cast_to=httpx.Response,
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
)

assert _get_open_connections(self.client) == 0

@mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
respx_mock.get("/account").mock(return_value=httpx.Response(500))
respx_mock.post("/predictions").mock(return_value=httpx.Response(500))

with pytest.raises(APIStatusError):
await self.client.get(
"/account", cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}
await self.client.post(
"/predictions",
body=cast(
object,
maybe_transform(
dict(
input={"text": "Alice"},
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
),
PredictionCreateParams,
),
),
cast_to=httpx.Response,
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
)

assert _get_open_connections(self.client) == 0
Expand Down Expand Up @@ -1570,9 +1628,9 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
return httpx.Response(500)
return httpx.Response(200)

respx_mock.get("/account").mock(side_effect=retry_handler)
respx_mock.post("/predictions").mock(side_effect=retry_handler)

response = await client.account.with_raw_response.get()
response = await client.predictions.with_raw_response.create(input={}, version="version")

assert response.retries_taken == failures_before_success
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
Expand All @@ -1595,9 +1653,11 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
return httpx.Response(500)
return httpx.Response(200)

respx_mock.get("/account").mock(side_effect=retry_handler)
respx_mock.post("/predictions").mock(side_effect=retry_handler)

response = await client.account.with_raw_response.get(extra_headers={"x-stainless-retry-count": Omit()})
response = await client.predictions.with_raw_response.create(
input={}, version="version", extra_headers={"x-stainless-retry-count": Omit()}
)

assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

Expand All @@ -1619,9 +1679,11 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
return httpx.Response(500)
return httpx.Response(200)

respx_mock.get("/account").mock(side_effect=retry_handler)
respx_mock.post("/predictions").mock(side_effect=retry_handler)

response = await client.account.with_raw_response.get(extra_headers={"x-stainless-retry-count": "42"})
response = await client.predictions.with_raw_response.create(
input={}, version="version", extra_headers={"x-stainless-retry-count": "42"}
)

assert response.http_request.headers.get("x-stainless-retry-count") == "42"

Expand Down