Skip to content

Commit eb0ba44

Browse files
docs: update example requests
1 parent 3f6d237 commit eb0ba44

File tree

3 files changed

+130
-47
lines changed

3 files changed

+130
-47
lines changed

.stats.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
configured_endpoints: 35
22
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-efbc8cc2d74644b213e161d3e11e0589d1cef181fb318ea02c8eb6b00f245713.yml
33
openapi_spec_hash: 13da0c06c900b61cd98ab678e024987a
4-
config_hash: 8ef6787524fd12bfeb27f8c6acef3dca
4+
config_hash: 84794ed69d841684ff08a8aa889ef103

README.md

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ client = Replicate(
3131
bearer_token=os.environ.get("REPLICATE_API_TOKEN"), # This is the default and can be omitted
3232
)
3333

34-
account = client.account.get()
35-
print(account.type)
34+
prediction = client.predictions.get(
35+
prediction_id="gm3qorzdhgbfurvjtvhg6dckhu",
36+
)
37+
print(prediction.id)
3638
```
3739

3840
While you can provide a `bearer_token` keyword argument,
@@ -55,8 +57,10 @@ client = AsyncReplicate(
5557

5658

5759
async def main() -> None:
58-
account = await client.account.get()
59-
print(account.type)
60+
prediction = await client.predictions.get(
61+
prediction_id="gm3qorzdhgbfurvjtvhg6dckhu",
62+
)
63+
print(prediction.id)
6064

6165

6266
asyncio.run(main())
@@ -84,12 +88,12 @@ from replicate import Replicate
8488

8589
client = Replicate()
8690

87-
all_predictions = []
91+
all_models = []
8892
# Automatically fetches more pages as needed.
89-
for prediction in client.predictions.list():
90-
# Do something with prediction here
91-
all_predictions.append(prediction)
92-
print(all_predictions)
93+
for model in client.models.list():
94+
# Do something with model here
95+
all_models.append(model)
96+
print(all_models)
9397
```
9498

9599
Or, asynchronously:
@@ -102,11 +106,11 @@ client = AsyncReplicate()
102106

103107

104108
async def main() -> None:
105-
all_predictions = []
109+
all_models = []
106110
# Iterate through items across all pages, issuing requests as needed.
107-
async for prediction in client.predictions.list():
108-
all_predictions.append(prediction)
109-
print(all_predictions)
111+
async for model in client.models.list():
112+
all_models.append(model)
113+
print(all_models)
110114

111115

112116
asyncio.run(main())
@@ -115,7 +119,7 @@ asyncio.run(main())
115119
Alternatively, you can use the `.has_next_page()`, `.next_page_info()`, or `.get_next_page()` methods for more granular control working with pages:
116120

117121
```python
118-
first_page = await client.predictions.list()
122+
first_page = await client.models.list()
119123
if first_page.has_next_page():
120124
print(f"will fetch next page using these details: {first_page.next_page_info()}")
121125
next_page = await first_page.get_next_page()
@@ -127,11 +131,11 @@ if first_page.has_next_page():
127131
Or just work directly with the returned data:
128132

129133
```python
130-
first_page = await client.predictions.list()
134+
first_page = await client.models.list()
131135

132136
print(f"next URL: {first_page.next}") # => "next URL: ..."
133-
for prediction in first_page.results:
134-
print(prediction.id)
137+
for model in first_page.results:
138+
print(model.cover_image_url)
135139

136140
# Remove `await` for non-async usage.
137141
```
@@ -170,7 +174,10 @@ from replicate import Replicate
170174
client = Replicate()
171175

172176
try:
173-
client.account.get()
177+
client.predictions.create(
178+
input={"text": "Alice"},
179+
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
180+
)
174181
except replicate.APIConnectionError as e:
175182
print("The server could not be reached")
176183
print(e.__cause__) # an underlying Exception, likely raised within httpx.
@@ -213,7 +220,10 @@ client = Replicate(
213220
)
214221

215222
# Or, configure per-request:
216-
client.with_options(max_retries=5).account.get()
223+
client.with_options(max_retries=5).predictions.create(
224+
input={"text": "Alice"},
225+
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
226+
)
217227
```
218228

219229
### Timeouts
@@ -236,7 +246,10 @@ client = Replicate(
236246
)
237247

238248
# Override per-request:
239-
client.with_options(timeout=5.0).account.get()
249+
client.with_options(timeout=5.0).predictions.create(
250+
input={"text": "Alice"},
251+
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
252+
)
240253
```
241254

242255
On timeout, an `APITimeoutError` is thrown.
@@ -277,11 +290,16 @@ The "raw" Response object can be accessed by prefixing `.with_raw_response.` to
277290
from replicate import Replicate
278291

279292
client = Replicate()
280-
response = client.account.with_raw_response.get()
293+
response = client.predictions.with_raw_response.create(
294+
input={
295+
"text": "Alice"
296+
},
297+
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
298+
)
281299
print(response.headers.get('X-My-Header'))
282300

283-
account = response.parse() # get the object that `account.get()` would have returned
284-
print(account.type)
301+
prediction = response.parse() # get the object that `predictions.create()` would have returned
302+
print(prediction.id)
285303
```
286304

287305
These methods return an [`APIResponse`](https://github.com/replicate/replicate-python-stainless/tree/main/src/replicate/_response.py) object.
@@ -295,7 +313,10 @@ The above interface eagerly reads the full response body when you make the reque
295313
To stream the response body, use `.with_streaming_response` instead, which requires a context manager and only reads the response body once you call `.read()`, `.text()`, `.json()`, `.iter_bytes()`, `.iter_text()`, `.iter_lines()` or `.parse()`. In the async client, these are async methods.
296314

297315
```python
298-
with client.account.with_streaming_response.get() as response:
316+
with client.predictions.with_streaming_response.create(
317+
input={"text": "Alice"},
318+
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
319+
) as response:
299320
print(response.headers.get("X-My-Header"))
300321

301322
for line in response.iter_lines():

tests/test_client.py

Lines changed: 84 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from replicate import Replicate, AsyncReplicate, APIResponseValidationError
2525
from replicate._types import Omit
26+
from replicate._utils import maybe_transform
2627
from replicate._models import BaseModel, FinalRequestOptions
2728
from replicate._constants import RAW_RESPONSE_HEADER
2829
from replicate._exceptions import APIStatusError, ReplicateError, APITimeoutError, APIResponseValidationError
@@ -32,6 +33,7 @@
3233
BaseClient,
3334
make_request_options,
3435
)
36+
from replicate.types.prediction_create_params import PredictionCreateParams
3537

3638
from .utils import update_env
3739

@@ -740,20 +742,48 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
740742
@mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
741743
@pytest.mark.respx(base_url=base_url)
742744
def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
743-
respx_mock.get("/account").mock(side_effect=httpx.TimeoutException("Test timeout error"))
745+
respx_mock.post("/predictions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
744746

745747
with pytest.raises(APITimeoutError):
746-
self.client.get("/account", cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}})
748+
self.client.post(
749+
"/predictions",
750+
body=cast(
751+
object,
752+
maybe_transform(
753+
dict(
754+
input={"text": "Alice"},
755+
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
756+
),
757+
PredictionCreateParams,
758+
),
759+
),
760+
cast_to=httpx.Response,
761+
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
762+
)
747763

748764
assert _get_open_connections(self.client) == 0
749765

750766
@mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
751767
@pytest.mark.respx(base_url=base_url)
752768
def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
753-
respx_mock.get("/account").mock(return_value=httpx.Response(500))
769+
respx_mock.post("/predictions").mock(return_value=httpx.Response(500))
754770

755771
with pytest.raises(APIStatusError):
756-
self.client.get("/account", cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}})
772+
self.client.post(
773+
"/predictions",
774+
body=cast(
775+
object,
776+
maybe_transform(
777+
dict(
778+
input={"text": "Alice"},
779+
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
780+
),
781+
PredictionCreateParams,
782+
),
783+
),
784+
cast_to=httpx.Response,
785+
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
786+
)
757787

758788
assert _get_open_connections(self.client) == 0
759789

@@ -781,9 +811,9 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
781811
return httpx.Response(500)
782812
return httpx.Response(200)
783813

784-
respx_mock.get("/account").mock(side_effect=retry_handler)
814+
respx_mock.post("/predictions").mock(side_effect=retry_handler)
785815

786-
response = client.account.with_raw_response.get()
816+
response = client.predictions.with_raw_response.create(input={}, version="version")
787817

788818
assert response.retries_taken == failures_before_success
789819
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
@@ -805,9 +835,11 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
805835
return httpx.Response(500)
806836
return httpx.Response(200)
807837

808-
respx_mock.get("/account").mock(side_effect=retry_handler)
838+
respx_mock.post("/predictions").mock(side_effect=retry_handler)
809839

810-
response = client.account.with_raw_response.get(extra_headers={"x-stainless-retry-count": Omit()})
840+
response = client.predictions.with_raw_response.create(
841+
input={}, version="version", extra_headers={"x-stainless-retry-count": Omit()}
842+
)
811843

812844
assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
813845

@@ -828,9 +860,11 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
828860
return httpx.Response(500)
829861
return httpx.Response(200)
830862

831-
respx_mock.get("/account").mock(side_effect=retry_handler)
863+
respx_mock.post("/predictions").mock(side_effect=retry_handler)
832864

833-
response = client.account.with_raw_response.get(extra_headers={"x-stainless-retry-count": "42"})
865+
response = client.predictions.with_raw_response.create(
866+
input={}, version="version", extra_headers={"x-stainless-retry-count": "42"}
867+
)
834868

835869
assert response.http_request.headers.get("x-stainless-retry-count") == "42"
836870

@@ -1524,23 +1558,47 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte
15241558
@mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
15251559
@pytest.mark.respx(base_url=base_url)
15261560
async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
1527-
respx_mock.get("/account").mock(side_effect=httpx.TimeoutException("Test timeout error"))
1561+
respx_mock.post("/predictions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
15281562

15291563
with pytest.raises(APITimeoutError):
1530-
await self.client.get(
1531-
"/account", cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}
1564+
await self.client.post(
1565+
"/predictions",
1566+
body=cast(
1567+
object,
1568+
maybe_transform(
1569+
dict(
1570+
input={"text": "Alice"},
1571+
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
1572+
),
1573+
PredictionCreateParams,
1574+
),
1575+
),
1576+
cast_to=httpx.Response,
1577+
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
15321578
)
15331579

15341580
assert _get_open_connections(self.client) == 0
15351581

15361582
@mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
15371583
@pytest.mark.respx(base_url=base_url)
15381584
async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
1539-
respx_mock.get("/account").mock(return_value=httpx.Response(500))
1585+
respx_mock.post("/predictions").mock(return_value=httpx.Response(500))
15401586

15411587
with pytest.raises(APIStatusError):
1542-
await self.client.get(
1543-
"/account", cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}
1588+
await self.client.post(
1589+
"/predictions",
1590+
body=cast(
1591+
object,
1592+
maybe_transform(
1593+
dict(
1594+
input={"text": "Alice"},
1595+
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
1596+
),
1597+
PredictionCreateParams,
1598+
),
1599+
),
1600+
cast_to=httpx.Response,
1601+
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
15441602
)
15451603

15461604
assert _get_open_connections(self.client) == 0
@@ -1570,9 +1628,9 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
15701628
return httpx.Response(500)
15711629
return httpx.Response(200)
15721630

1573-
respx_mock.get("/account").mock(side_effect=retry_handler)
1631+
respx_mock.post("/predictions").mock(side_effect=retry_handler)
15741632

1575-
response = await client.account.with_raw_response.get()
1633+
response = await client.predictions.with_raw_response.create(input={}, version="version")
15761634

15771635
assert response.retries_taken == failures_before_success
15781636
assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
@@ -1595,9 +1653,11 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
15951653
return httpx.Response(500)
15961654
return httpx.Response(200)
15971655

1598-
respx_mock.get("/account").mock(side_effect=retry_handler)
1656+
respx_mock.post("/predictions").mock(side_effect=retry_handler)
15991657

1600-
response = await client.account.with_raw_response.get(extra_headers={"x-stainless-retry-count": Omit()})
1658+
response = await client.predictions.with_raw_response.create(
1659+
input={}, version="version", extra_headers={"x-stainless-retry-count": Omit()}
1660+
)
16011661

16021662
assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
16031663

@@ -1619,9 +1679,11 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
16191679
return httpx.Response(500)
16201680
return httpx.Response(200)
16211681

1622-
respx_mock.get("/account").mock(side_effect=retry_handler)
1682+
respx_mock.post("/predictions").mock(side_effect=retry_handler)
16231683

1624-
response = await client.account.with_raw_response.get(extra_headers={"x-stainless-retry-count": "42"})
1684+
response = await client.predictions.with_raw_response.create(
1685+
input={}, version="version", extra_headers={"x-stainless-retry-count": "42"}
1686+
)
16251687

16261688
assert response.http_request.headers.get("x-stainless-retry-count") == "42"
16271689

0 commit comments

Comments
 (0)