Skip to content

Commit 3dfe4f7

Browse files
fix(tests): fix: tests which call HTTP endpoints directly with the example parameters
1 parent 4f54c7a commit 3dfe4f7

File tree

1 file changed

+24
-69
lines changed

1 file changed

+24
-69
lines changed

tests/test_client.py

Lines changed: 24 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@
2323

2424
from replicate import Replicate, AsyncReplicate, APIResponseValidationError
2525
from replicate._types import Omit
26-
from replicate._utils import maybe_transform
2726
from replicate._models import BaseModel, FinalRequestOptions
28-
from replicate._constants import RAW_RESPONSE_HEADER
2927
from replicate._exceptions import APIStatusError, ReplicateError, APITimeoutError, APIResponseValidationError
3028
from replicate._base_client import (
3129
DEFAULT_TIMEOUT,
@@ -35,7 +33,6 @@
3533
DefaultAsyncHttpxClient,
3634
make_request_options,
3735
)
38-
from replicate.types.prediction_create_params import PredictionCreateParams
3936

4037
from .utils import update_env
4138

@@ -743,50 +740,27 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
743740

744741
@mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
745742
@pytest.mark.respx(base_url=base_url)
746-
def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
743+
def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: Replicate) -> None:
747744
respx_mock.post("/predictions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
748745

749746
with pytest.raises(APITimeoutError):
750-
self.client.post(
751-
"/predictions",
752-
body=cast(
753-
object,
754-
maybe_transform(
755-
dict(
756-
input={"text": "Alice"},
757-
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
758-
),
759-
PredictionCreateParams,
760-
),
761-
),
762-
cast_to=httpx.Response,
763-
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
764-
)
747+
client.predictions.with_streaming_response.create(
748+
input={"text": "Alice"},
749+
version="replicate/hello-world:9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426",
750+
).__enter__()
765751

766752
assert _get_open_connections(self.client) == 0
767753

768754
@mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
769755
@pytest.mark.respx(base_url=base_url)
770-
def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
756+
def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: Replicate) -> None:
771757
respx_mock.post("/predictions").mock(return_value=httpx.Response(500))
772758

773759
with pytest.raises(APIStatusError):
774-
self.client.post(
775-
"/predictions",
776-
body=cast(
777-
object,
778-
maybe_transform(
779-
dict(
780-
input={"text": "Alice"},
781-
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
782-
),
783-
PredictionCreateParams,
784-
),
785-
),
786-
cast_to=httpx.Response,
787-
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
788-
)
789-
760+
client.predictions.with_streaming_response.create(
761+
input={"text": "Alice"},
762+
version="replicate/hello-world:9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426",
763+
).__enter__()
790764
assert _get_open_connections(self.client) == 0
791765

792766
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@@ -1615,50 +1589,31 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte
16151589

16161590
@mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
16171591
@pytest.mark.respx(base_url=base_url)
1618-
async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
1592+
async def test_retrying_timeout_errors_doesnt_leak(
1593+
self, respx_mock: MockRouter, async_client: AsyncReplicate
1594+
) -> None:
16191595
respx_mock.post("/predictions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
16201596

16211597
with pytest.raises(APITimeoutError):
1622-
await self.client.post(
1623-
"/predictions",
1624-
body=cast(
1625-
object,
1626-
maybe_transform(
1627-
dict(
1628-
input={"text": "Alice"},
1629-
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
1630-
),
1631-
PredictionCreateParams,
1632-
),
1633-
),
1634-
cast_to=httpx.Response,
1635-
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
1636-
)
1598+
await async_client.predictions.with_streaming_response.create(
1599+
input={"text": "Alice"},
1600+
version="replicate/hello-world:9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426",
1601+
).__aenter__()
16371602

16381603
assert _get_open_connections(self.client) == 0
16391604

16401605
@mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
16411606
@pytest.mark.respx(base_url=base_url)
1642-
async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
1607+
async def test_retrying_status_errors_doesnt_leak(
1608+
self, respx_mock: MockRouter, async_client: AsyncReplicate
1609+
) -> None:
16431610
respx_mock.post("/predictions").mock(return_value=httpx.Response(500))
16441611

16451612
with pytest.raises(APIStatusError):
1646-
await self.client.post(
1647-
"/predictions",
1648-
body=cast(
1649-
object,
1650-
maybe_transform(
1651-
dict(
1652-
input={"text": "Alice"},
1653-
version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
1654-
),
1655-
PredictionCreateParams,
1656-
),
1657-
),
1658-
cast_to=httpx.Response,
1659-
options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
1660-
)
1661-
1613+
await async_client.predictions.with_streaming_response.create(
1614+
input={"text": "Alice"},
1615+
version="replicate/hello-world:9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426",
1616+
).__aenter__()
16621617
assert _get_open_connections(self.client) == 0
16631618

16641619
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])

0 commit comments

Comments
 (0)