|
23 | 23 |
|
24 | 24 | from replicate import Replicate, AsyncReplicate, APIResponseValidationError |
25 | 25 | from replicate._types import Omit |
26 | | -from replicate._utils import maybe_transform |
27 | 26 | from replicate._models import BaseModel, FinalRequestOptions |
28 | | -from replicate._constants import RAW_RESPONSE_HEADER |
29 | 27 | from replicate._exceptions import APIStatusError, ReplicateError, APITimeoutError, APIResponseValidationError |
30 | 28 | from replicate._base_client import ( |
31 | 29 | DEFAULT_TIMEOUT, |
|
35 | 33 | DefaultAsyncHttpxClient, |
36 | 34 | make_request_options, |
37 | 35 | ) |
38 | | -from replicate.types.prediction_create_params import PredictionCreateParams |
39 | 36 |
|
40 | 37 | from .utils import update_env |
41 | 38 |
|
@@ -743,50 +740,27 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str |
743 | 740 |
|
744 | 741 | @mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) |
745 | 742 | @pytest.mark.respx(base_url=base_url) |
746 | | - def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: |
| 743 | + def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: Replicate) -> None: |
747 | 744 | respx_mock.post("/predictions").mock(side_effect=httpx.TimeoutException("Test timeout error")) |
748 | 745 |
|
749 | 746 | with pytest.raises(APITimeoutError): |
750 | | - self.client.post( |
751 | | - "/predictions", |
752 | | - body=cast( |
753 | | - object, |
754 | | - maybe_transform( |
755 | | - dict( |
756 | | - input={"text": "Alice"}, |
757 | | - version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", |
758 | | - ), |
759 | | - PredictionCreateParams, |
760 | | - ), |
761 | | - ), |
762 | | - cast_to=httpx.Response, |
763 | | - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, |
764 | | - ) |
| 747 | + client.predictions.with_streaming_response.create( |
| 748 | + input={"text": "Alice"}, |
| 749 | + version="replicate/hello-world:9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426", |
| 750 | + ).__enter__() |
765 | 751 |
|
766 | 752 | assert _get_open_connections(self.client) == 0 |
767 | 753 |
|
768 | 754 | @mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) |
769 | 755 | @pytest.mark.respx(base_url=base_url) |
770 | | - def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: |
| 756 | + def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: Replicate) -> None: |
771 | 757 | respx_mock.post("/predictions").mock(return_value=httpx.Response(500)) |
772 | 758 |
|
773 | 759 | with pytest.raises(APIStatusError): |
774 | | - self.client.post( |
775 | | - "/predictions", |
776 | | - body=cast( |
777 | | - object, |
778 | | - maybe_transform( |
779 | | - dict( |
780 | | - input={"text": "Alice"}, |
781 | | - version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", |
782 | | - ), |
783 | | - PredictionCreateParams, |
784 | | - ), |
785 | | - ), |
786 | | - cast_to=httpx.Response, |
787 | | - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, |
788 | | - ) |
789 | | - |
| 760 | + client.predictions.with_streaming_response.create( |
| 761 | + input={"text": "Alice"}, |
| 762 | + version="replicate/hello-world:9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426", |
| 763 | + ).__enter__() |
790 | 764 | assert _get_open_connections(self.client) == 0 |
791 | 765 |
|
792 | 766 | @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) |
@@ -1615,50 +1589,31 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte |
1615 | 1589 |
|
1616 | 1590 | @mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) |
1617 | 1591 | @pytest.mark.respx(base_url=base_url) |
1618 | | - async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: |
| 1592 | + async def test_retrying_timeout_errors_doesnt_leak( |
| 1593 | + self, respx_mock: MockRouter, async_client: AsyncReplicate |
| 1594 | + ) -> None: |
1619 | 1595 | respx_mock.post("/predictions").mock(side_effect=httpx.TimeoutException("Test timeout error")) |
1620 | 1596 |
|
1621 | 1597 | with pytest.raises(APITimeoutError): |
1622 | | - await self.client.post( |
1623 | | - "/predictions", |
1624 | | - body=cast( |
1625 | | - object, |
1626 | | - maybe_transform( |
1627 | | - dict( |
1628 | | - input={"text": "Alice"}, |
1629 | | - version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", |
1630 | | - ), |
1631 | | - PredictionCreateParams, |
1632 | | - ), |
1633 | | - ), |
1634 | | - cast_to=httpx.Response, |
1635 | | - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, |
1636 | | - ) |
| 1598 | + await async_client.predictions.with_streaming_response.create( |
| 1599 | + input={"text": "Alice"}, |
| 1600 | + version="replicate/hello-world:9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426", |
| 1601 | + ).__aenter__() |
1637 | 1602 |
|
1638 | 1603 | assert _get_open_connections(self.client) == 0 |
1639 | 1604 |
|
1640 | 1605 | @mock.patch("replicate._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) |
1641 | 1606 | @pytest.mark.respx(base_url=base_url) |
1642 | | - async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: |
| 1607 | + async def test_retrying_status_errors_doesnt_leak( |
| 1608 | + self, respx_mock: MockRouter, async_client: AsyncReplicate |
| 1609 | + ) -> None: |
1643 | 1610 | respx_mock.post("/predictions").mock(return_value=httpx.Response(500)) |
1644 | 1611 |
|
1645 | 1612 | with pytest.raises(APIStatusError): |
1646 | | - await self.client.post( |
1647 | | - "/predictions", |
1648 | | - body=cast( |
1649 | | - object, |
1650 | | - maybe_transform( |
1651 | | - dict( |
1652 | | - input={"text": "Alice"}, |
1653 | | - version="replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", |
1654 | | - ), |
1655 | | - PredictionCreateParams, |
1656 | | - ), |
1657 | | - ), |
1658 | | - cast_to=httpx.Response, |
1659 | | - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, |
1660 | | - ) |
1661 | | - |
| 1613 | + await async_client.predictions.with_streaming_response.create( |
| 1614 | + input={"text": "Alice"}, |
| 1615 | + version="replicate/hello-world:9dcd6d78e7c6560c340d916fe32e9f24aabfa331e5cce95fe31f77fb03121426", |
| 1616 | + ).__aenter__() |
1662 | 1617 | assert _get_open_connections(self.client) == 0 |
1663 | 1618 |
|
1664 | 1619 | @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) |
|
0 commit comments