2323
2424from replicate import Replicate , AsyncReplicate , APIResponseValidationError
2525from replicate ._types import Omit
26+ from replicate ._utils import maybe_transform
2627from replicate ._models import BaseModel , FinalRequestOptions
2728from replicate ._constants import RAW_RESPONSE_HEADER
2829from replicate ._exceptions import APIStatusError , ReplicateError , APITimeoutError , APIResponseValidationError
3233 BaseClient ,
3334 make_request_options ,
3435)
36+ from replicate .types .prediction_create_params import PredictionCreateParams
3537
3638from .utils import update_env
3739
@@ -740,20 +742,48 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str
740742 @mock .patch ("replicate._base_client.BaseClient._calculate_retry_timeout" , _low_retry_timeout )
741743 @pytest .mark .respx (base_url = base_url )
742744 def test_retrying_timeout_errors_doesnt_leak (self , respx_mock : MockRouter ) -> None :
743- respx_mock .get ("/account " ).mock (side_effect = httpx .TimeoutException ("Test timeout error" ))
745+ respx_mock .post ("/predictions " ).mock (side_effect = httpx .TimeoutException ("Test timeout error" ))
744746
745747 with pytest .raises (APITimeoutError ):
746- self .client .get ("/account" , cast_to = httpx .Response , options = {"headers" : {RAW_RESPONSE_HEADER : "stream" }})
748+ self .client .post (
749+ "/predictions" ,
750+ body = cast (
751+ object ,
752+ maybe_transform (
753+ dict (
754+ input = {"text" : "Alice" },
755+ version = "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
756+ ),
757+ PredictionCreateParams ,
758+ ),
759+ ),
760+ cast_to = httpx .Response ,
761+ options = {"headers" : {RAW_RESPONSE_HEADER : "stream" }},
762+ )
747763
748764 assert _get_open_connections (self .client ) == 0
749765
750766 @mock .patch ("replicate._base_client.BaseClient._calculate_retry_timeout" , _low_retry_timeout )
751767 @pytest .mark .respx (base_url = base_url )
752768 def test_retrying_status_errors_doesnt_leak (self , respx_mock : MockRouter ) -> None :
753- respx_mock .get ("/account " ).mock (return_value = httpx .Response (500 ))
769+ respx_mock .post ("/predictions " ).mock (return_value = httpx .Response (500 ))
754770
755771 with pytest .raises (APIStatusError ):
756- self .client .get ("/account" , cast_to = httpx .Response , options = {"headers" : {RAW_RESPONSE_HEADER : "stream" }})
772+ self .client .post (
773+ "/predictions" ,
774+ body = cast (
775+ object ,
776+ maybe_transform (
777+ dict (
778+ input = {"text" : "Alice" },
779+ version = "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
780+ ),
781+ PredictionCreateParams ,
782+ ),
783+ ),
784+ cast_to = httpx .Response ,
785+ options = {"headers" : {RAW_RESPONSE_HEADER : "stream" }},
786+ )
757787
758788 assert _get_open_connections (self .client ) == 0
759789
@@ -781,9 +811,9 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
781811 return httpx .Response (500 )
782812 return httpx .Response (200 )
783813
784- respx_mock .get ("/account " ).mock (side_effect = retry_handler )
814+ respx_mock .post ("/predictions " ).mock (side_effect = retry_handler )
785815
786- response = client .account .with_raw_response .get ( )
816+ response = client .predictions .with_raw_response .create ( input = {}, version = "version" )
787817
788818 assert response .retries_taken == failures_before_success
789819 assert int (response .http_request .headers .get ("x-stainless-retry-count" )) == failures_before_success
@@ -805,9 +835,11 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
805835 return httpx .Response (500 )
806836 return httpx .Response (200 )
807837
808- respx_mock .get ("/account " ).mock (side_effect = retry_handler )
838+ respx_mock .post ("/predictions " ).mock (side_effect = retry_handler )
809839
810- response = client .account .with_raw_response .get (extra_headers = {"x-stainless-retry-count" : Omit ()})
840+ response = client .predictions .with_raw_response .create (
841+ input = {}, version = "version" , extra_headers = {"x-stainless-retry-count" : Omit ()}
842+ )
811843
812844 assert len (response .http_request .headers .get_list ("x-stainless-retry-count" )) == 0
813845
@@ -828,9 +860,11 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
828860 return httpx .Response (500 )
829861 return httpx .Response (200 )
830862
831- respx_mock .get ("/account " ).mock (side_effect = retry_handler )
863+ respx_mock .post ("/predictions " ).mock (side_effect = retry_handler )
832864
833- response = client .account .with_raw_response .get (extra_headers = {"x-stainless-retry-count" : "42" })
865+ response = client .predictions .with_raw_response .create (
866+ input = {}, version = "version" , extra_headers = {"x-stainless-retry-count" : "42" }
867+ )
834868
835869 assert response .http_request .headers .get ("x-stainless-retry-count" ) == "42"
836870
@@ -1524,23 +1558,47 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte
15241558 @mock .patch ("replicate._base_client.BaseClient._calculate_retry_timeout" , _low_retry_timeout )
15251559 @pytest .mark .respx (base_url = base_url )
15261560 async def test_retrying_timeout_errors_doesnt_leak (self , respx_mock : MockRouter ) -> None :
1527- respx_mock .get ("/account " ).mock (side_effect = httpx .TimeoutException ("Test timeout error" ))
1561+ respx_mock .post ("/predictions " ).mock (side_effect = httpx .TimeoutException ("Test timeout error" ))
15281562
15291563 with pytest .raises (APITimeoutError ):
1530- await self .client .get (
1531- "/account" , cast_to = httpx .Response , options = {"headers" : {RAW_RESPONSE_HEADER : "stream" }}
1564+ await self .client .post (
1565+ "/predictions" ,
1566+ body = cast (
1567+ object ,
1568+ maybe_transform (
1569+ dict (
1570+ input = {"text" : "Alice" },
1571+ version = "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
1572+ ),
1573+ PredictionCreateParams ,
1574+ ),
1575+ ),
1576+ cast_to = httpx .Response ,
1577+ options = {"headers" : {RAW_RESPONSE_HEADER : "stream" }},
15321578 )
15331579
15341580 assert _get_open_connections (self .client ) == 0
15351581
15361582 @mock .patch ("replicate._base_client.BaseClient._calculate_retry_timeout" , _low_retry_timeout )
15371583 @pytest .mark .respx (base_url = base_url )
15381584 async def test_retrying_status_errors_doesnt_leak (self , respx_mock : MockRouter ) -> None :
1539- respx_mock .get ("/account " ).mock (return_value = httpx .Response (500 ))
1585+ respx_mock .post ("/predictions " ).mock (return_value = httpx .Response (500 ))
15401586
15411587 with pytest .raises (APIStatusError ):
1542- await self .client .get (
1543- "/account" , cast_to = httpx .Response , options = {"headers" : {RAW_RESPONSE_HEADER : "stream" }}
1588+ await self .client .post (
1589+ "/predictions" ,
1590+ body = cast (
1591+ object ,
1592+ maybe_transform (
1593+ dict (
1594+ input = {"text" : "Alice" },
1595+ version = "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
1596+ ),
1597+ PredictionCreateParams ,
1598+ ),
1599+ ),
1600+ cast_to = httpx .Response ,
1601+ options = {"headers" : {RAW_RESPONSE_HEADER : "stream" }},
15441602 )
15451603
15461604 assert _get_open_connections (self .client ) == 0
@@ -1570,9 +1628,9 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
15701628 return httpx .Response (500 )
15711629 return httpx .Response (200 )
15721630
1573- respx_mock .get ("/account " ).mock (side_effect = retry_handler )
1631+ respx_mock .post ("/predictions " ).mock (side_effect = retry_handler )
15741632
1575- response = await client .account .with_raw_response .get ( )
1633+ response = await client .predictions .with_raw_response .create ( input = {}, version = "version" )
15761634
15771635 assert response .retries_taken == failures_before_success
15781636 assert int (response .http_request .headers .get ("x-stainless-retry-count" )) == failures_before_success
@@ -1595,9 +1653,11 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
15951653 return httpx .Response (500 )
15961654 return httpx .Response (200 )
15971655
1598- respx_mock .get ("/account " ).mock (side_effect = retry_handler )
1656+ respx_mock .post ("/predictions " ).mock (side_effect = retry_handler )
15991657
1600- response = await client .account .with_raw_response .get (extra_headers = {"x-stainless-retry-count" : Omit ()})
1658+ response = await client .predictions .with_raw_response .create (
1659+ input = {}, version = "version" , extra_headers = {"x-stainless-retry-count" : Omit ()}
1660+ )
16011661
16021662 assert len (response .http_request .headers .get_list ("x-stainless-retry-count" )) == 0
16031663
@@ -1619,9 +1679,11 @@ def retry_handler(_request: httpx.Request) -> httpx.Response:
16191679 return httpx .Response (500 )
16201680 return httpx .Response (200 )
16211681
1622- respx_mock .get ("/account " ).mock (side_effect = retry_handler )
1682+ respx_mock .post ("/predictions " ).mock (side_effect = retry_handler )
16231683
1624- response = await client .account .with_raw_response .get (extra_headers = {"x-stainless-retry-count" : "42" })
1684+ response = await client .predictions .with_raw_response .create (
1685+ input = {}, version = "version" , extra_headers = {"x-stainless-retry-count" : "42" }
1686+ )
16251687
16261688 assert response .http_request .headers .get ("x-stainless-retry-count" ) == "42"
16271689
0 commit comments