From 00eab7702f8f2699ce9b3070f23202278ac21855 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 16 Apr 2025 19:14:13 +0000 Subject: [PATCH 01/12] fix(api)!: use correct env var for bearer token The correct env var to use is REPLICATE_API_TOKEN, not REPLICATE_CLIENT_BEARER_TOKEN --- .stats.yml | 2 +- README.md | 10 +++------- src/replicate/_client.py | 12 ++++++------ tests/test_client.py | 4 ++-- 4 files changed, 12 insertions(+), 16 deletions(-) diff --git a/.stats.yml b/.stats.yml index 91aadf3..263f930 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 27 openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-37bb31ed76da599d3bded543a3765f745c8575d105c13554df7f8361c3641482.yml openapi_spec_hash: 15bdec12ca84042768bfb28cc48dfce3 -config_hash: 810de4c2eee1a7649263cff01f00da7c +config_hash: 64fd304bf0ff077a74053ce79a255106 diff --git a/README.md b/README.md index 0015b4c..23b57ae 100644 --- a/README.md +++ b/README.md @@ -28,9 +28,7 @@ import os from replicate import ReplicateClient client = ReplicateClient( - bearer_token=os.environ.get( - "REPLICATE_CLIENT_BEARER_TOKEN" - ), # This is the default and can be omitted + bearer_token=os.environ.get("REPLICATE_API_TOKEN"), # This is the default and can be omitted ) accounts = client.accounts.list() @@ -39,7 +37,7 @@ print(accounts.type) While you can provide a `bearer_token` keyword argument, we recommend using [python-dotenv](https://pypi.org/project/python-dotenv/) -to add `REPLICATE_CLIENT_BEARER_TOKEN="My Bearer Token"` to your `.env` file +to add `REPLICATE_API_TOKEN="My Bearer Token"` to your `.env` file so that your Bearer Token is not stored in source control. ## Async usage @@ -52,9 +50,7 @@ import asyncio from replicate import AsyncReplicateClient client = AsyncReplicateClient( - bearer_token=os.environ.get( - "REPLICATE_CLIENT_BEARER_TOKEN" - ), # This is the default and can be omitted + bearer_token=os.environ.get("REPLICATE_API_TOKEN"), # This is the default and can be omitted ) diff --git a/src/replicate/_client.py b/src/replicate/_client.py index 06e86ff..b59e805 100644 --- a/src/replicate/_client.py +++ b/src/replicate/_client.py @@ -88,13 +88,13 @@ def __init__( ) -> None: """Construct a new synchronous ReplicateClient client instance. - This automatically infers the `bearer_token` argument from the `REPLICATE_CLIENT_BEARER_TOKEN` environment variable if it is not provided. + This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided. """ if bearer_token is None: - bearer_token = os.environ.get("REPLICATE_CLIENT_BEARER_TOKEN") + bearer_token = os.environ.get("REPLICATE_API_TOKEN") if bearer_token is None: raise ReplicateClientError( - "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_CLIENT_BEARER_TOKEN environment variable" + "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable" ) self.bearer_token = bearer_token @@ -270,13 +270,13 @@ def __init__( ) -> None: """Construct a new async AsyncReplicateClient client instance. - This automatically infers the `bearer_token` argument from the `REPLICATE_CLIENT_BEARER_TOKEN` environment variable if it is not provided. + This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided. """ if bearer_token is None: - bearer_token = os.environ.get("REPLICATE_CLIENT_BEARER_TOKEN") + bearer_token = os.environ.get("REPLICATE_API_TOKEN") if bearer_token is None: raise ReplicateClientError( - "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_CLIENT_BEARER_TOKEN environment variable" + "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable" ) self.bearer_token = bearer_token diff --git a/tests/test_client.py b/tests/test_client.py index bd5b999..279675f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -346,7 +346,7 @@ def test_validate_headers(self) -> None: assert request.headers.get("Authorization") == f"Bearer {bearer_token}" with pytest.raises(ReplicateClientError): - with update_env(**{"REPLICATE_CLIENT_BEARER_TOKEN": Omit()}): + with update_env(**{"REPLICATE_API_TOKEN": Omit()}): client2 = ReplicateClient(base_url=base_url, bearer_token=None, _strict_response_validation=True) _ = client2 @@ -1126,7 +1126,7 @@ def test_validate_headers(self) -> None: assert request.headers.get("Authorization") == f"Bearer {bearer_token}" with pytest.raises(ReplicateClientError): - with update_env(**{"REPLICATE_CLIENT_BEARER_TOKEN": Omit()}): + with update_env(**{"REPLICATE_API_TOKEN": Omit()}): client2 = AsyncReplicateClient(base_url=base_url, bearer_token=None, _strict_response_validation=True) _ = client2 From 7ebd59873181c74dbaa035ac599abcbbefb3ee62 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 16 Apr 2025 23:30:55 +0000 Subject: [PATCH 02/12] feat(api): api update --- .stats.yml | 4 +- api.md | 20 ++++- src/replicate/resources/models/versions.py | 11 ++- src/replicate/resources/trainings.py | 46 ++++++------ src/replicate/types/__init__.py | 3 + .../types/deployment_list_response.py | 6 +- src/replicate/types/models/__init__.py | 1 + .../version_create_training_response.py | 74 +++++++++++++++++++ src/replicate/types/prediction_output.py | 2 +- .../types/training_cancel_response.py | 74 +++++++++++++++++++ src/replicate/types/training_list_response.py | 74 +++++++++++++++++++ .../types/training_retrieve_response.py | 74 +++++++++++++++++++ tests/api_resources/models/test_versions.py | 18 +++-- tests/api_resources/test_trainings.py | 39 +++++----- 14 files changed, 379 insertions(+), 67 deletions(-) create mode 100644 src/replicate/types/models/version_create_training_response.py create mode 100644 src/replicate/types/training_cancel_response.py create mode 100644 src/replicate/types/training_list_response.py create mode 100644 src/replicate/types/training_retrieve_response.py diff --git a/.stats.yml b/.stats.yml index 263f930..1448ef2 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 27 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-37bb31ed76da599d3bded543a3765f745c8575d105c13554df7f8361c3641482.yml -openapi_spec_hash: 15bdec12ca84042768bfb28cc48dfce3 +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-2788217b7ad7d61d1a77800bc5ff12a6810f1692d4d770b72fa8f898c6a055ab.yml +openapi_spec_hash: 4423bf747e228484547b441468a9f156 config_hash: 64fd304bf0ff077a74053ce79a255106 diff --git a/api.md b/api.md index 60e91e6..509868d 100644 --- a/api.md +++ b/api.md @@ -75,12 +75,18 @@ Methods: ## Versions +Types: + +```python +from replicate.types.models import VersionCreateTrainingResponse +``` + Methods: - client.models.versions.retrieve(version_id, \*, model_owner, model_name) -> None - client.models.versions.list(model_name, \*, model_owner) -> None - client.models.versions.delete(version_id, \*, model_owner, model_name) -> None -- client.models.versions.create_training(version_id, \*, model_owner, model_name, \*\*params) -> None +- client.models.versions.create_training(version_id, \*, model_owner, model_name, \*\*params) -> VersionCreateTrainingResponse # Predictions @@ -99,11 +105,17 @@ Methods: # Trainings +Types: + +```python +from replicate.types import TrainingRetrieveResponse, TrainingListResponse, TrainingCancelResponse +``` + Methods: -- client.trainings.retrieve(training_id) -> None -- client.trainings.list() -> None -- client.trainings.cancel(training_id) -> None +- client.trainings.retrieve(training_id) -> TrainingRetrieveResponse +- client.trainings.list() -> SyncCursorURLPage[TrainingListResponse] +- client.trainings.cancel(training_id) -> TrainingCancelResponse # Webhooks diff --git a/src/replicate/resources/models/versions.py b/src/replicate/resources/models/versions.py index eb411ef..203bb5d 100644 --- a/src/replicate/resources/models/versions.py +++ b/src/replicate/resources/models/versions.py @@ -22,6 +22,7 @@ ) from ..._base_client import make_request_options from ...types.models import version_create_training_params +from ...types.models.version_create_training_response import VersionCreateTrainingResponse __all__ = ["VersionsResource", "AsyncVersionsResource"] @@ -281,7 +282,7 @@ def create_training( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> VersionCreateTrainingResponse: """ Start a new training of the model version you specify. @@ -398,7 +399,6 @@ def create_training( raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") if not version_id: raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} return self._post( f"/models/{model_owner}/{model_name}/versions/{version_id}/trainings", body=maybe_transform( @@ -413,7 +413,7 @@ def create_training( options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + cast_to=VersionCreateTrainingResponse, ) @@ -672,7 +672,7 @@ async def create_training( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> VersionCreateTrainingResponse: """ Start a new training of the model version you specify. @@ -789,7 +789,6 @@ async def create_training( raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") if not version_id: raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} return await self._post( f"/models/{model_owner}/{model_name}/versions/{version_id}/trainings", body=await async_maybe_transform( @@ -804,7 +803,7 @@ async def create_training( options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + cast_to=VersionCreateTrainingResponse, ) diff --git a/src/replicate/resources/trainings.py b/src/replicate/resources/trainings.py index df64c19..46c5872 100644 --- a/src/replicate/resources/trainings.py +++ b/src/replicate/resources/trainings.py @@ -4,7 +4,7 @@ import httpx -from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven +from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource from .._response import ( @@ -13,7 +13,11 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) -from .._base_client import make_request_options +from ..pagination import SyncCursorURLPage, AsyncCursorURLPage +from .._base_client import AsyncPaginator, make_request_options +from ..types.training_list_response import TrainingListResponse +from ..types.training_cancel_response import TrainingCancelResponse +from ..types.training_retrieve_response import TrainingRetrieveResponse __all__ = ["TrainingsResource", "AsyncTrainingsResource"] @@ -48,7 +52,7 @@ def retrieve( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> TrainingRetrieveResponse: """ Get the current state of a training. @@ -123,13 +127,12 @@ def retrieve( """ if not training_id: raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} return self._get( f"/trainings/{training_id}", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + cast_to=TrainingRetrieveResponse, ) def list( @@ -141,7 +144,7 @@ def list( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> SyncCursorURLPage[TrainingListResponse]: """ Get a paginated list of all trainings created by the user or organization associated with the provided API token. @@ -207,13 +210,13 @@ def list( `version` will be the unique ID of model version used to create the training. """ - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return self._get( + return self._get_api_list( "/trainings", + page=SyncCursorURLPage[TrainingListResponse], options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + model=TrainingListResponse, ) def cancel( @@ -226,7 +229,7 @@ def cancel( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> TrainingCancelResponse: """ Cancel a training @@ -241,13 +244,12 @@ def cancel( """ if not training_id: raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} return self._post( f"/trainings/{training_id}/cancel", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + cast_to=TrainingCancelResponse, ) @@ -281,7 +283,7 @@ async def retrieve( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> TrainingRetrieveResponse: """ Get the current state of a training. @@ -356,16 +358,15 @@ async def retrieve( """ if not training_id: raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} return await self._get( f"/trainings/{training_id}", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + cast_to=TrainingRetrieveResponse, ) - async def list( + def list( self, *, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. @@ -374,7 +375,7 @@ async def list( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> AsyncPaginator[TrainingListResponse, AsyncCursorURLPage[TrainingListResponse]]: """ Get a paginated list of all trainings created by the user or organization associated with the provided API token. @@ -440,13 +441,13 @@ async def list( `version` will be the unique ID of model version used to create the training. """ - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return await self._get( + return self._get_api_list( "/trainings", + page=AsyncCursorURLPage[TrainingListResponse], options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + model=TrainingListResponse, ) async def cancel( @@ -459,7 +460,7 @@ async def cancel( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> TrainingCancelResponse: """ Cancel a training @@ -474,13 +475,12 @@ async def cancel( """ if not training_id: raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} return await self._post( f"/trainings/{training_id}/cancel", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + cast_to=TrainingCancelResponse, ) diff --git a/src/replicate/types/__init__.py b/src/replicate/types/__init__.py index e2b3c58..7d2d8f4 100644 --- a/src/replicate/types/__init__.py +++ b/src/replicate/types/__init__.py @@ -9,11 +9,14 @@ from .account_list_response import AccountListResponse as AccountListResponse from .hardware_list_response import HardwareListResponse as HardwareListResponse from .prediction_list_params import PredictionListParams as PredictionListParams +from .training_list_response import TrainingListResponse as TrainingListResponse from .deployment_create_params import DeploymentCreateParams as DeploymentCreateParams from .deployment_list_response import DeploymentListResponse as DeploymentListResponse from .deployment_update_params import DeploymentUpdateParams as DeploymentUpdateParams from .prediction_create_params import PredictionCreateParams as PredictionCreateParams +from .training_cancel_response import TrainingCancelResponse as TrainingCancelResponse from .deployment_create_response import DeploymentCreateResponse as DeploymentCreateResponse from .deployment_update_response import DeploymentUpdateResponse as DeploymentUpdateResponse +from .training_retrieve_response import TrainingRetrieveResponse as TrainingRetrieveResponse from .deployment_retrieve_response import DeploymentRetrieveResponse as DeploymentRetrieveResponse from .model_create_prediction_params import ModelCreatePredictionParams as ModelCreatePredictionParams diff --git a/src/replicate/types/deployment_list_response.py b/src/replicate/types/deployment_list_response.py index 9f64487..2606d38 100644 --- a/src/replicate/types/deployment_list_response.py +++ b/src/replicate/types/deployment_list_response.py @@ -49,11 +49,7 @@ class CurrentRelease(BaseModel): """The model identifier string in the format of `{model_owner}/{model_name}`.""" number: Optional[int] = None - """The release number. - - This is an auto-incrementing integer that starts at 1, and is set automatically - when a deployment is created. - """ + """The release number.""" version: Optional[str] = None """The ID of the model version used in the release.""" diff --git a/src/replicate/types/models/__init__.py b/src/replicate/types/models/__init__.py index 98e1a3a..2d89bc6 100644 --- a/src/replicate/types/models/__init__.py +++ b/src/replicate/types/models/__init__.py @@ -3,3 +3,4 @@ from __future__ import annotations from .version_create_training_params import VersionCreateTrainingParams as VersionCreateTrainingParams +from .version_create_training_response import VersionCreateTrainingResponse as VersionCreateTrainingResponse diff --git a/src/replicate/types/models/version_create_training_response.py b/src/replicate/types/models/version_create_training_response.py new file mode 100644 index 0000000..5b005fe --- /dev/null +++ b/src/replicate/types/models/version_create_training_response.py @@ -0,0 +1,74 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, Optional +from datetime import datetime +from typing_extensions import Literal + +from ..._models import BaseModel + +__all__ = ["VersionCreateTrainingResponse", "Metrics", "Output", "URLs"] + + +class Metrics(BaseModel): + predict_time: Optional[float] = None + """The amount of CPU or GPU time, in seconds, that the training used while running""" + + +class Output(BaseModel): + version: Optional[str] = None + """The version of the model created by the training""" + + weights: Optional[str] = None + """The weights of the trained model""" + + +class URLs(BaseModel): + cancel: Optional[str] = None + """URL to cancel the training""" + + get: Optional[str] = None + """URL to get the training details""" + + +class VersionCreateTrainingResponse(BaseModel): + id: Optional[str] = None + """The unique ID of the training""" + + completed_at: Optional[datetime] = None + """The time when the training completed""" + + created_at: Optional[datetime] = None + """The time when the training was created""" + + error: Optional[str] = None + """Error message if the training failed""" + + input: Optional[Dict[str, object]] = None + """The input parameters used for the training""" + + logs: Optional[str] = None + """The logs from the training process""" + + metrics: Optional[Metrics] = None + """Metrics about the training process""" + + model: Optional[str] = None + """The name of the model in the format owner/name""" + + output: Optional[Output] = None + """The output of the training process""" + + source: Optional[Literal["web", "api"]] = None + """How the training was created""" + + started_at: Optional[datetime] = None + """The time when the training started""" + + status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None + """The current status of the training""" + + urls: Optional[URLs] = None + """URLs for interacting with the training""" + + version: Optional[str] = None + """The ID of the model version used for training""" diff --git a/src/replicate/types/prediction_output.py b/src/replicate/types/prediction_output.py index 2c46100..8b320b5 100644 --- a/src/replicate/types/prediction_output.py +++ b/src/replicate/types/prediction_output.py @@ -6,5 +6,5 @@ __all__ = ["PredictionOutput"] PredictionOutput: TypeAlias = Union[ - Optional[Dict[str, object]], Optional[List[object]], Optional[str], Optional[float], Optional[bool] + Optional[Dict[str, object]], Optional[List[Dict[str, object]]], Optional[str], Optional[float], Optional[bool] ] diff --git a/src/replicate/types/training_cancel_response.py b/src/replicate/types/training_cancel_response.py new file mode 100644 index 0000000..9af512b --- /dev/null +++ b/src/replicate/types/training_cancel_response.py @@ -0,0 +1,74 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, Optional +from datetime import datetime +from typing_extensions import Literal + +from .._models import BaseModel + +__all__ = ["TrainingCancelResponse", "Metrics", "Output", "URLs"] + + +class Metrics(BaseModel): + predict_time: Optional[float] = None + """The amount of CPU or GPU time, in seconds, that the training used while running""" + + +class Output(BaseModel): + version: Optional[str] = None + """The version of the model created by the training""" + + weights: Optional[str] = None + """The weights of the trained model""" + + +class URLs(BaseModel): + cancel: Optional[str] = None + """URL to cancel the training""" + + get: Optional[str] = None + """URL to get the training details""" + + +class TrainingCancelResponse(BaseModel): + id: Optional[str] = None + """The unique ID of the training""" + + completed_at: Optional[datetime] = None + """The time when the training completed""" + + created_at: Optional[datetime] = None + """The time when the training was created""" + + error: Optional[str] = None + """Error message if the training failed""" + + input: Optional[Dict[str, object]] = None + """The input parameters used for the training""" + + logs: Optional[str] = None + """The logs from the training process""" + + metrics: Optional[Metrics] = None + """Metrics about the training process""" + + model: Optional[str] = None + """The name of the model in the format owner/name""" + + output: Optional[Output] = None + """The output of the training process""" + + source: Optional[Literal["web", "api"]] = None + """How the training was created""" + + started_at: Optional[datetime] = None + """The time when the training started""" + + status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None + """The current status of the training""" + + urls: Optional[URLs] = None + """URLs for interacting with the training""" + + version: Optional[str] = None + """The ID of the model version used for training""" diff --git a/src/replicate/types/training_list_response.py b/src/replicate/types/training_list_response.py new file mode 100644 index 0000000..02adf3a --- /dev/null +++ b/src/replicate/types/training_list_response.py @@ -0,0 +1,74 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, Optional +from datetime import datetime +from typing_extensions import Literal + +from .._models import BaseModel + +__all__ = ["TrainingListResponse", "Metrics", "Output", "URLs"] + + +class Metrics(BaseModel): + predict_time: Optional[float] = None + """The amount of CPU or GPU time, in seconds, that the training used while running""" + + +class Output(BaseModel): + version: Optional[str] = None + """The version of the model created by the training""" + + weights: Optional[str] = None + """The weights of the trained model""" + + +class URLs(BaseModel): + cancel: Optional[str] = None + """URL to cancel the training""" + + get: Optional[str] = None + """URL to get the training details""" + + +class TrainingListResponse(BaseModel): + id: Optional[str] = None + """The unique ID of the training""" + + completed_at: Optional[datetime] = None + """The time when the training completed""" + + created_at: Optional[datetime] = None + """The time when the training was created""" + + error: Optional[str] = None + """Error message if the training failed""" + + input: Optional[Dict[str, object]] = None + """The input parameters used for the training""" + + logs: Optional[str] = None + """The logs from the training process""" + + metrics: Optional[Metrics] = None + """Metrics about the training process""" + + model: Optional[str] = None + """The name of the model in the format owner/name""" + + output: Optional[Output] = None + """The output of the training process""" + + source: Optional[Literal["web", "api"]] = None + """How the training was created""" + + started_at: Optional[datetime] = None + """The time when the training started""" + + status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None + """The current status of the training""" + + urls: Optional[URLs] = None + """URLs for interacting with the training""" + + version: Optional[str] = None + """The ID of the model version used for training""" diff --git a/src/replicate/types/training_retrieve_response.py b/src/replicate/types/training_retrieve_response.py new file mode 100644 index 0000000..55b1173 --- /dev/null +++ b/src/replicate/types/training_retrieve_response.py @@ -0,0 +1,74 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, Optional +from datetime import datetime +from typing_extensions import Literal + +from .._models import BaseModel + +__all__ = ["TrainingRetrieveResponse", "Metrics", "Output", "URLs"] + + +class Metrics(BaseModel): + predict_time: Optional[float] = None + """The amount of CPU or GPU time, in seconds, that the training used while running""" + + +class Output(BaseModel): + version: Optional[str] = None + """The version of the model created by the training""" + + weights: Optional[str] = None + """The weights of the trained model""" + + +class URLs(BaseModel): + cancel: Optional[str] = None + """URL to cancel the training""" + + get: Optional[str] = None + """URL to get the training details""" + + +class TrainingRetrieveResponse(BaseModel): + id: Optional[str] = None + """The unique ID of the training""" + + completed_at: Optional[datetime] = None + """The time when the training completed""" + + created_at: Optional[datetime] = None + """The time when the training was created""" + + error: Optional[str] = None + """Error message if the training failed""" + + input: Optional[Dict[str, object]] = None + """The input parameters used for the training""" + + logs: Optional[str] = None + """The logs from the training process""" + + metrics: Optional[Metrics] = None + """Metrics about the training process""" + + model: Optional[str] = None + """The name of the model in the format owner/name""" + + output: Optional[Output] = None + """The output of the training process""" + + source: Optional[Literal["web", "api"]] = None + """How the training was created""" + + started_at: Optional[datetime] = None + """The time when the training started""" + + status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None + """The current status of the training""" + + urls: Optional[URLs] = None + """URLs for interacting with the training""" + + version: Optional[str] = None + """The ID of the model version used for training""" diff --git a/tests/api_resources/models/test_versions.py b/tests/api_resources/models/test_versions.py index 295205e..7e6d583 100644 --- a/tests/api_resources/models/test_versions.py +++ b/tests/api_resources/models/test_versions.py @@ -8,6 +8,8 @@ import pytest from replicate import ReplicateClient, AsyncReplicateClient +from tests.utils import assert_matches_type +from replicate.types.models import VersionCreateTrainingResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -205,7 +207,7 @@ def test_method_create_training(self, client: ReplicateClient) -> None: destination="destination", input={}, ) - assert version is None + assert_matches_type(VersionCreateTrainingResponse, version, path=["response"]) @pytest.mark.skip() @parametrize @@ -219,7 +221,7 @@ def test_method_create_training_with_all_params(self, client: ReplicateClient) - webhook="webhook", webhook_events_filter=["start"], ) - assert version is None + assert_matches_type(VersionCreateTrainingResponse, version, path=["response"]) @pytest.mark.skip() @parametrize @@ -235,7 +237,7 @@ def test_raw_response_create_training(self, client: ReplicateClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" version = response.parse() - assert version is None + assert_matches_type(VersionCreateTrainingResponse, version, path=["response"]) @pytest.mark.skip() @parametrize @@ -251,7 +253,7 @@ def test_streaming_response_create_training(self, client: ReplicateClient) -> No assert response.http_request.headers.get("X-Stainless-Lang") == "python" version = response.parse() - assert version is None + assert_matches_type(VersionCreateTrainingResponse, version, path=["response"]) assert cast(Any, response.is_closed) is True @@ -479,7 +481,7 @@ async def test_method_create_training(self, async_client: AsyncReplicateClient) destination="destination", input={}, ) - assert version is None + assert_matches_type(VersionCreateTrainingResponse, version, path=["response"]) @pytest.mark.skip() @parametrize @@ -493,7 +495,7 @@ async def test_method_create_training_with_all_params(self, async_client: AsyncR webhook="webhook", webhook_events_filter=["start"], ) - assert version is None + assert_matches_type(VersionCreateTrainingResponse, version, path=["response"]) @pytest.mark.skip() @parametrize @@ -509,7 +511,7 @@ async def test_raw_response_create_training(self, async_client: AsyncReplicateCl assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" version = await response.parse() - assert version is None + assert_matches_type(VersionCreateTrainingResponse, version, path=["response"]) @pytest.mark.skip() @parametrize @@ -525,7 +527,7 @@ async def test_streaming_response_create_training(self, async_client: AsyncRepli assert response.http_request.headers.get("X-Stainless-Lang") == "python" version = await response.parse() - assert version is None + assert_matches_type(VersionCreateTrainingResponse, version, path=["response"]) assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_trainings.py b/tests/api_resources/test_trainings.py index f59d874..68b1a85 100644 --- a/tests/api_resources/test_trainings.py +++ b/tests/api_resources/test_trainings.py @@ -8,6 +8,9 @@ import pytest from replicate import ReplicateClient, AsyncReplicateClient +from tests.utils import assert_matches_type +from replicate.types import TrainingListResponse, TrainingCancelResponse, TrainingRetrieveResponse +from replicate.pagination import SyncCursorURLPage, AsyncCursorURLPage base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -21,7 +24,7 @@ def test_method_retrieve(self, client: ReplicateClient) -> None: training = client.trainings.retrieve( "training_id", ) - assert training is None + assert_matches_type(TrainingRetrieveResponse, training, path=["response"]) @pytest.mark.skip() @parametrize @@ -33,7 +36,7 @@ def test_raw_response_retrieve(self, client: ReplicateClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = response.parse() - assert training is None + assert_matches_type(TrainingRetrieveResponse, training, path=["response"]) @pytest.mark.skip() @parametrize @@ -45,7 +48,7 @@ def test_streaming_response_retrieve(self, client: ReplicateClient) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = response.parse() - assert training is None + assert_matches_type(TrainingRetrieveResponse, training, path=["response"]) assert cast(Any, response.is_closed) is True @@ -61,7 +64,7 @@ def test_path_params_retrieve(self, client: ReplicateClient) -> None: @parametrize def test_method_list(self, client: ReplicateClient) -> None: training = client.trainings.list() - assert training is None + assert_matches_type(SyncCursorURLPage[TrainingListResponse], training, path=["response"]) @pytest.mark.skip() @parametrize @@ -71,7 +74,7 @@ def test_raw_response_list(self, client: ReplicateClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = response.parse() - assert training is None + assert_matches_type(SyncCursorURLPage[TrainingListResponse], training, path=["response"]) @pytest.mark.skip() @parametrize @@ -81,7 +84,7 @@ def test_streaming_response_list(self, client: ReplicateClient) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = response.parse() - assert training is None + assert_matches_type(SyncCursorURLPage[TrainingListResponse], training, path=["response"]) assert cast(Any, response.is_closed) is True @@ -91,7 +94,7 @@ def test_method_cancel(self, client: ReplicateClient) -> None: training = client.trainings.cancel( "training_id", ) - assert training is None + assert_matches_type(TrainingCancelResponse, training, path=["response"]) @pytest.mark.skip() @parametrize @@ -103,7 +106,7 @@ def test_raw_response_cancel(self, client: ReplicateClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = response.parse() - assert training is None + assert_matches_type(TrainingCancelResponse, training, path=["response"]) @pytest.mark.skip() @parametrize @@ -115,7 +118,7 @@ def test_streaming_response_cancel(self, client: ReplicateClient) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = response.parse() - assert training is None + assert_matches_type(TrainingCancelResponse, training, path=["response"]) assert cast(Any, response.is_closed) is True @@ -137,7 +140,7 @@ async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None training = await async_client.trainings.retrieve( "training_id", ) - assert training is None + assert_matches_type(TrainingRetrieveResponse, training, path=["response"]) @pytest.mark.skip() @parametrize @@ -149,7 +152,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) - assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = await response.parse() - assert training is None + assert_matches_type(TrainingRetrieveResponse, training, path=["response"]) @pytest.mark.skip() @parametrize @@ -161,7 +164,7 @@ async def test_streaming_response_retrieve(self, async_client: AsyncReplicateCli assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = await response.parse() - assert training is None + assert_matches_type(TrainingRetrieveResponse, training, path=["response"]) assert cast(Any, response.is_closed) is True @@ -177,7 +180,7 @@ async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> @parametrize async def test_method_list(self, async_client: AsyncReplicateClient) -> None: training = await async_client.trainings.list() - assert training is None + assert_matches_type(AsyncCursorURLPage[TrainingListResponse], training, path=["response"]) @pytest.mark.skip() @parametrize @@ -187,7 +190,7 @@ async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> No assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = await response.parse() - assert training is None + assert_matches_type(AsyncCursorURLPage[TrainingListResponse], training, path=["response"]) @pytest.mark.skip() @parametrize @@ -197,7 +200,7 @@ async def test_streaming_response_list(self, async_client: AsyncReplicateClient) assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = await response.parse() - assert training is None + assert_matches_type(AsyncCursorURLPage[TrainingListResponse], training, path=["response"]) assert cast(Any, response.is_closed) is True @@ -207,7 +210,7 @@ async def test_method_cancel(self, async_client: AsyncReplicateClient) -> None: training = await async_client.trainings.cancel( "training_id", ) - assert training is None + assert_matches_type(TrainingCancelResponse, training, path=["response"]) @pytest.mark.skip() @parametrize @@ -219,7 +222,7 @@ async def test_raw_response_cancel(self, async_client: AsyncReplicateClient) -> assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = await response.parse() - assert training is None + assert_matches_type(TrainingCancelResponse, training, path=["response"]) @pytest.mark.skip() @parametrize @@ -231,7 +234,7 @@ async def test_streaming_response_cancel(self, async_client: AsyncReplicateClien assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = await response.parse() - assert training is None + assert_matches_type(TrainingCancelResponse, training, path=["response"]) assert cast(Any, response.is_closed) is True From f1e4d140104ff317b94cb2dd88ec850a9b8bce54 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Thu, 17 Apr 2025 03:30:58 +0000 Subject: [PATCH 03/12] chore(internal): bump pyright version --- pyproject.toml | 2 +- requirements-dev.lock | 2 +- src/replicate/_base_client.py | 6 +++++- src/replicate/_models.py | 1 - src/replicate/_utils/_typing.py | 2 +- tests/conftest.py | 2 +- tests/test_models.py | 2 +- 7 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d9681a8..8eb0e54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ Repository = "https://github.com/replicate/replicate-python-stainless" managed = true # version pins are in requirements-dev.lock dev-dependencies = [ - "pyright>=1.1.359", + "pyright==1.1.399", "mypy", "respx", "pytest", diff --git a/requirements-dev.lock b/requirements-dev.lock index e9ea098..86eea12 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -69,7 +69,7 @@ pydantic-core==2.27.1 # via pydantic pygments==2.18.0 # via rich -pyright==1.1.392.post0 +pyright==1.1.399 pytest==8.3.3 # via pytest-asyncio pytest-asyncio==0.24.0 diff --git a/src/replicate/_base_client.py b/src/replicate/_base_client.py index d55fecf..3d26f28 100644 --- a/src/replicate/_base_client.py +++ b/src/replicate/_base_client.py @@ -98,7 +98,11 @@ _AsyncStreamT = TypeVar("_AsyncStreamT", bound=AsyncStream[Any]) if TYPE_CHECKING: - from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT + from httpx._config import ( + DEFAULT_TIMEOUT_CONFIG, # pyright: ignore[reportPrivateImportUsage] + ) + + HTTPX_DEFAULT_TIMEOUT = DEFAULT_TIMEOUT_CONFIG else: try: from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT diff --git a/src/replicate/_models.py b/src/replicate/_models.py index 3493571..58b9263 100644 --- a/src/replicate/_models.py +++ b/src/replicate/_models.py @@ -19,7 +19,6 @@ ) import pydantic -import pydantic.generics from pydantic.fields import FieldInfo from ._types import ( diff --git a/src/replicate/_utils/_typing.py b/src/replicate/_utils/_typing.py index 1958820..1bac954 100644 --- a/src/replicate/_utils/_typing.py +++ b/src/replicate/_utils/_typing.py @@ -110,7 +110,7 @@ class MyResponse(Foo[_T]): ``` """ cls = cast(object, get_origin(typ) or typ) - if cls in generic_bases: + if cls in generic_bases: # pyright: ignore[reportUnnecessaryContains] # we're given the class directly return extract_type_arg(typ, index) diff --git a/tests/conftest.py b/tests/conftest.py index 97007f1..7d1538b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,7 @@ from replicate import ReplicateClient, AsyncReplicateClient if TYPE_CHECKING: - from _pytest.fixtures import FixtureRequest + from _pytest.fixtures import FixtureRequest # pyright: ignore[reportPrivateImportUsage] pytest.register_assert_rewrite("tests.utils") diff --git a/tests/test_models.py b/tests/test_models.py index ee374fa..25f3d43 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -832,7 +832,7 @@ class B(BaseModel): @pytest.mark.skipif(not PYDANTIC_V2, reason="TypeAliasType is not supported in Pydantic v1") def test_type_alias_type() -> None: - Alias = TypeAliasType("Alias", str) + Alias = TypeAliasType("Alias", str) # pyright: ignore class Model(BaseModel): alias: Alias From c1d6ed59ed0f06012922ec6d0bae376852523d81 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Thu, 17 Apr 2025 03:31:38 +0000 Subject: [PATCH 04/12] chore(internal): base client updates --- src/replicate/_base_client.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/replicate/_base_client.py b/src/replicate/_base_client.py index 3d26f28..fb06997 100644 --- a/src/replicate/_base_client.py +++ b/src/replicate/_base_client.py @@ -119,6 +119,7 @@ class PageInfo: url: URL | NotGiven params: Query | NotGiven + json: Body | NotGiven @overload def __init__( @@ -134,19 +135,30 @@ def __init__( params: Query, ) -> None: ... + @overload + def __init__( + self, + *, + json: Body, + ) -> None: ... + def __init__( self, *, url: URL | NotGiven = NOT_GIVEN, + json: Body | NotGiven = NOT_GIVEN, params: Query | NotGiven = NOT_GIVEN, ) -> None: self.url = url + self.json = json self.params = params @override def __repr__(self) -> str: if self.url: return f"{self.__class__.__name__}(url={self.url})" + if self.json: + return f"{self.__class__.__name__}(json={self.json})" return f"{self.__class__.__name__}(params={self.params})" @@ -195,6 +207,19 @@ def _info_to_options(self, info: PageInfo) -> FinalRequestOptions: options.url = str(url) return options + if not isinstance(info.json, NotGiven): + if not is_mapping(info.json): + raise TypeError("Pagination is only supported with mappings") + + if not options.json_data: + options.json_data = {**info.json} + else: + if not is_mapping(options.json_data): + raise TypeError("Pagination is only supported with mappings") + + options.json_data = {**options.json_data, **info.json} + return options + raise ValueError("Unexpected PageInfo state") From a812f1b8f719c301ebf98a9d919a382e411ef247 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Thu, 17 Apr 2025 11:15:40 +0000 Subject: [PATCH 05/12] replace `retrieve` with `get` for method names --- .stats.yml | 2 +- api.md | 14 +- .../resources/deployments/deployments.py | 310 ++++++------ src/replicate/resources/models/models.py | 460 +++++++++--------- src/replicate/resources/models/versions.py | 260 +++++----- src/replicate/resources/predictions.py | 432 ++++++++-------- src/replicate/resources/trainings.py | 262 +++++----- src/replicate/types/__init__.py | 4 +- ...response.py => deployment_get_response.py} | 4 +- ...e_response.py => training_get_response.py} | 4 +- tests/api_resources/models/test_versions.py | 160 +++--- tests/api_resources/test_deployments.py | 210 ++++---- tests/api_resources/test_models.py | 208 ++++---- tests/api_resources/test_predictions.py | 168 +++---- tests/api_resources/test_trainings.py | 120 ++--- 15 files changed, 1309 insertions(+), 1309 deletions(-) rename src/replicate/types/{deployment_retrieve_response.py => deployment_get_response.py} (91%) rename src/replicate/types/{training_retrieve_response.py => training_get_response.py} (94%) diff --git a/.stats.yml b/.stats.yml index 1448ef2..190fe36 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 27 openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-2788217b7ad7d61d1a77800bc5ff12a6810f1692d4d770b72fa8f898c6a055ab.yml openapi_spec_hash: 4423bf747e228484547b441468a9f156 -config_hash: 64fd304bf0ff077a74053ce79a255106 +config_hash: d1d273c0d97d034d24c7eac8ef51d2ac diff --git a/api.md b/api.md index 509868d..e972297 100644 --- a/api.md +++ b/api.md @@ -11,19 +11,19 @@ Types: ```python from replicate.types import ( DeploymentCreateResponse, - DeploymentRetrieveResponse, DeploymentUpdateResponse, DeploymentListResponse, + DeploymentGetResponse, ) ``` Methods: - client.deployments.create(\*\*params) -> DeploymentCreateResponse -- client.deployments.retrieve(deployment_name, \*, deployment_owner) -> DeploymentRetrieveResponse - client.deployments.update(deployment_name, \*, deployment_owner, \*\*params) -> DeploymentUpdateResponse - client.deployments.list() -> SyncCursorURLPage[DeploymentListResponse] - client.deployments.delete(deployment_name, \*, deployment_owner) -> None +- client.deployments.get(deployment_name, \*, deployment_owner) -> DeploymentGetResponse - client.deployments.list_em_all() -> None ## Predictions @@ -68,10 +68,10 @@ from replicate.types import ModelListResponse Methods: - client.models.create(\*\*params) -> None -- client.models.retrieve(model_name, \*, model_owner) -> None - client.models.list() -> SyncCursorURLPage[ModelListResponse] - client.models.delete(model_name, \*, model_owner) -> None - client.models.create_prediction(model_name, \*, model_owner, \*\*params) -> Prediction +- client.models.get(model_name, \*, model_owner) -> None ## Versions @@ -83,10 +83,10 @@ from replicate.types.models import VersionCreateTrainingResponse Methods: -- client.models.versions.retrieve(version_id, \*, model_owner, model_name) -> None - client.models.versions.list(model_name, \*, model_owner) -> None - client.models.versions.delete(version_id, \*, model_owner, model_name) -> None - client.models.versions.create_training(version_id, \*, model_owner, model_name, \*\*params) -> VersionCreateTrainingResponse +- client.models.versions.get(version_id, \*, model_owner, model_name) -> None # Predictions @@ -99,23 +99,23 @@ from replicate.types import Prediction, PredictionOutput, PredictionRequest Methods: - client.predictions.create(\*\*params) -> Prediction -- client.predictions.retrieve(prediction_id) -> Prediction - client.predictions.list(\*\*params) -> SyncCursorURLPageWithCreatedFilters[Prediction] - client.predictions.cancel(prediction_id) -> None +- client.predictions.get(prediction_id) -> Prediction # Trainings Types: ```python -from replicate.types import TrainingRetrieveResponse, TrainingListResponse, TrainingCancelResponse +from replicate.types import TrainingListResponse, TrainingCancelResponse, TrainingGetResponse ``` Methods: -- client.trainings.retrieve(training_id) -> TrainingRetrieveResponse - client.trainings.list() -> SyncCursorURLPage[TrainingListResponse] - client.trainings.cancel(training_id) -> TrainingCancelResponse +- client.trainings.get(training_id) -> TrainingGetResponse # Webhooks diff --git a/src/replicate/resources/deployments/deployments.py b/src/replicate/resources/deployments/deployments.py index 0211b67..7e2fac7 100644 --- a/src/replicate/resources/deployments/deployments.py +++ b/src/replicate/resources/deployments/deployments.py @@ -28,10 +28,10 @@ ) from ...pagination import SyncCursorURLPage, AsyncCursorURLPage from ..._base_client import AsyncPaginator, make_request_options +from ...types.deployment_get_response import DeploymentGetResponse from ...types.deployment_list_response import DeploymentListResponse from ...types.deployment_create_response import DeploymentCreateResponse from ...types.deployment_update_response import DeploymentUpdateResponse -from ...types.deployment_retrieve_response import DeploymentRetrieveResponse __all__ = ["DeploymentsResource", "AsyncDeploymentsResource"] @@ -165,77 +165,6 @@ def create( cast_to=DeploymentCreateResponse, ) - def retrieve( - self, - deployment_name: str, - *, - deployment_owner: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> DeploymentRetrieveResponse: - """ - Get information about a deployment by name including the current release. - - Example cURL request: - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/deployments/replicate/my-app-image-generator - ``` - - The response will be a JSON object describing the deployment: - - ```json - { - "owner": "acme", - "name": "my-app-image-generator", - "current_release": { - "number": 1, - "model": "stability-ai/sdxl", - "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", - "created_at": "2024-02-15T16:32:57.018467Z", - "created_by": { - "type": "organization", - "username": "acme", - "name": "Acme Corp, Inc.", - "avatar_url": "https://cdn.replicate.com/avatars/acme.png", - "github_url": "https://github.com/acme" - }, - "configuration": { - "hardware": "gpu-t4", - "min_instances": 1, - "max_instances": 5 - } - } - } - ``` - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not deployment_owner: - raise ValueError(f"Expected a non-empty value for `deployment_owner` but received {deployment_owner!r}") - if not deployment_name: - raise ValueError(f"Expected a non-empty value for `deployment_name` but received {deployment_name!r}") - return self._get( - f"/deployments/{deployment_owner}/{deployment_name}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=DeploymentRetrieveResponse, - ) - def update( self, deployment_name: str, @@ -454,6 +383,77 @@ def delete( cast_to=NoneType, ) + def get( + self, + deployment_name: str, + *, + deployment_owner: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> DeploymentGetResponse: + """ + Get information about a deployment by name including the current release. + + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/deployments/replicate/my-app-image-generator + ``` + + The response will be a JSON object describing the deployment: + + ```json + { + "owner": "acme", + "name": "my-app-image-generator", + "current_release": { + "number": 1, + "model": "stability-ai/sdxl", + "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", + "created_at": "2024-02-15T16:32:57.018467Z", + "created_by": { + "type": "organization", + "username": "acme", + "name": "Acme Corp, Inc.", + "avatar_url": "https://cdn.replicate.com/avatars/acme.png", + "github_url": "https://github.com/acme" + }, + "configuration": { + "hardware": "gpu-t4", + "min_instances": 1, + "max_instances": 5 + } + } + } + ``` + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not deployment_owner: + raise ValueError(f"Expected a non-empty value for `deployment_owner` but received {deployment_owner!r}") + if not deployment_name: + raise ValueError(f"Expected a non-empty value for `deployment_name` but received {deployment_name!r}") + return self._get( + f"/deployments/{deployment_owner}/{deployment_name}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=DeploymentGetResponse, + ) + def list_em_all( self, *, @@ -628,77 +628,6 @@ async def create( cast_to=DeploymentCreateResponse, ) - async def retrieve( - self, - deployment_name: str, - *, - deployment_owner: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> DeploymentRetrieveResponse: - """ - Get information about a deployment by name including the current release. - - Example cURL request: - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/deployments/replicate/my-app-image-generator - ``` - - The response will be a JSON object describing the deployment: - - ```json - { - "owner": "acme", - "name": "my-app-image-generator", - "current_release": { - "number": 1, - "model": "stability-ai/sdxl", - "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", - "created_at": "2024-02-15T16:32:57.018467Z", - "created_by": { - "type": "organization", - "username": "acme", - "name": "Acme Corp, Inc.", - "avatar_url": "https://cdn.replicate.com/avatars/acme.png", - "github_url": "https://github.com/acme" - }, - "configuration": { - "hardware": "gpu-t4", - "min_instances": 1, - "max_instances": 5 - } - } - } - ``` - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not deployment_owner: - raise ValueError(f"Expected a non-empty value for `deployment_owner` but received {deployment_owner!r}") - if not deployment_name: - raise ValueError(f"Expected a non-empty value for `deployment_name` but received {deployment_name!r}") - return await self._get( - f"/deployments/{deployment_owner}/{deployment_name}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=DeploymentRetrieveResponse, - ) - async def update( self, deployment_name: str, @@ -917,6 +846,77 @@ async def delete( cast_to=NoneType, ) + async def get( + self, + deployment_name: str, + *, + deployment_owner: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> DeploymentGetResponse: + """ + Get information about a deployment by name including the current release. + + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/deployments/replicate/my-app-image-generator + ``` + + The response will be a JSON object describing the deployment: + + ```json + { + "owner": "acme", + "name": "my-app-image-generator", + "current_release": { + "number": 1, + "model": "stability-ai/sdxl", + "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", + "created_at": "2024-02-15T16:32:57.018467Z", + "created_by": { + "type": "organization", + "username": "acme", + "name": "Acme Corp, Inc.", + "avatar_url": "https://cdn.replicate.com/avatars/acme.png", + "github_url": "https://github.com/acme" + }, + "configuration": { + "hardware": "gpu-t4", + "min_instances": 1, + "max_instances": 5 + } + } + } + ``` + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not deployment_owner: + raise ValueError(f"Expected a non-empty value for `deployment_owner` but received {deployment_owner!r}") + if not deployment_name: + raise ValueError(f"Expected a non-empty value for `deployment_name` but received {deployment_name!r}") + return await self._get( + f"/deployments/{deployment_owner}/{deployment_name}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=DeploymentGetResponse, + ) + async def list_em_all( self, *, @@ -969,9 +969,6 @@ def __init__(self, deployments: DeploymentsResource) -> None: self.create = to_raw_response_wrapper( deployments.create, ) - self.retrieve = to_raw_response_wrapper( - deployments.retrieve, - ) self.update = to_raw_response_wrapper( deployments.update, ) @@ -981,6 +978,9 @@ def __init__(self, deployments: DeploymentsResource) -> None: self.delete = to_raw_response_wrapper( deployments.delete, ) + self.get = to_raw_response_wrapper( + deployments.get, + ) self.list_em_all = to_raw_response_wrapper( deployments.list_em_all, ) @@ -997,9 +997,6 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None: self.create = async_to_raw_response_wrapper( deployments.create, ) - self.retrieve = async_to_raw_response_wrapper( - deployments.retrieve, - ) self.update = async_to_raw_response_wrapper( deployments.update, ) @@ -1009,6 +1006,9 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None: self.delete = async_to_raw_response_wrapper( deployments.delete, ) + self.get = async_to_raw_response_wrapper( + deployments.get, + ) self.list_em_all = async_to_raw_response_wrapper( deployments.list_em_all, ) @@ -1025,9 +1025,6 @@ def __init__(self, deployments: DeploymentsResource) -> None: self.create = to_streamed_response_wrapper( deployments.create, ) - self.retrieve = to_streamed_response_wrapper( - deployments.retrieve, - ) self.update = to_streamed_response_wrapper( deployments.update, ) @@ -1037,6 +1034,9 @@ def __init__(self, deployments: DeploymentsResource) -> None: self.delete = to_streamed_response_wrapper( deployments.delete, ) + self.get = to_streamed_response_wrapper( + deployments.get, + ) self.list_em_all = to_streamed_response_wrapper( deployments.list_em_all, ) @@ -1053,9 +1053,6 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None: self.create = async_to_streamed_response_wrapper( deployments.create, ) - self.retrieve = async_to_streamed_response_wrapper( - deployments.retrieve, - ) self.update = async_to_streamed_response_wrapper( deployments.update, ) @@ -1065,6 +1062,9 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None: self.delete = async_to_streamed_response_wrapper( deployments.delete, ) + self.get = async_to_streamed_response_wrapper( + deployments.get, + ) self.list_em_all = async_to_streamed_response_wrapper( deployments.list_em_all, ) diff --git a/src/replicate/resources/models/models.py b/src/replicate/resources/models/models.py index 77145ae..6d8c6eb 100644 --- a/src/replicate/resources/models/models.py +++ b/src/replicate/resources/models/models.py @@ -174,115 +174,6 @@ def create( cast_to=NoneType, ) - def retrieve( - self, - model_name: str, - *, - model_owner: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: - """ - Example cURL request: - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/models/replicate/hello-world - ``` - - The response will be a model object in the following format: - - ```json - { - "url": "https://replicate.com/replicate/hello-world", - "owner": "replicate", - "name": "hello-world", - "description": "A tiny model that says hello", - "visibility": "public", - "github_url": "https://github.com/replicate/cog-examples", - "paper_url": null, - "license_url": null, - "run_count": 5681081, - "cover_image_url": "...", - "default_example": {...}, - "latest_version": {...}, - } - ``` - - The model object includes the - [input and output schema](https://replicate.com/docs/reference/openapi#model-schemas) - for the latest version of the model. - - Here's an example showing how to fetch the model with cURL and display its input - schema with [jq](https://stedolan.github.io/jq/): - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/models/replicate/hello-world \\ - | jq ".latest_version.openapi_schema.components.schemas.Input" - ``` - - This will return the following JSON object: - - ```json - { - "type": "object", - "title": "Input", - "required": ["text"], - "properties": { - "text": { - "type": "string", - "title": "Text", - "x-order": 0, - "description": "Text to prefix with 'hello '" - } - } - } - ``` - - The `cover_image_url` string is an HTTPS URL for an image file. This can be: - - - An image uploaded by the model author. - - The output file of the example prediction, if the model author has not set a - cover image. - - The input file of the example prediction, if the model author has not set a - cover image and the example prediction has no output file. - - A generic fallback image. - - The `default_example` object is a [prediction](#predictions.get) created with - this model. - - The `latest_version` object is the model's most recently pushed - [version](#models.versions.get). - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not model_owner: - raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") - if not model_name: - raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return self._get( - f"/models/{model_owner}/{model_name}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=NoneType, - ) - def list( self, *, @@ -514,6 +405,115 @@ def create_prediction( cast_to=Prediction, ) + def get( + self, + model_name: str, + *, + model_owner: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> None: + """ + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/models/replicate/hello-world + ``` + + The response will be a model object in the following format: + + ```json + { + "url": "https://replicate.com/replicate/hello-world", + "owner": "replicate", + "name": "hello-world", + "description": "A tiny model that says hello", + "visibility": "public", + "github_url": "https://github.com/replicate/cog-examples", + "paper_url": null, + "license_url": null, + "run_count": 5681081, + "cover_image_url": "...", + "default_example": {...}, + "latest_version": {...}, + } + ``` + + The model object includes the + [input and output schema](https://replicate.com/docs/reference/openapi#model-schemas) + for the latest version of the model. + + Here's an example showing how to fetch the model with cURL and display its input + schema with [jq](https://stedolan.github.io/jq/): + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/models/replicate/hello-world \\ + | jq ".latest_version.openapi_schema.components.schemas.Input" + ``` + + This will return the following JSON object: + + ```json + { + "type": "object", + "title": "Input", + "required": ["text"], + "properties": { + "text": { + "type": "string", + "title": "Text", + "x-order": 0, + "description": "Text to prefix with 'hello '" + } + } + } + ``` + + The `cover_image_url` string is an HTTPS URL for an image file. This can be: + + - An image uploaded by the model author. + - The output file of the example prediction, if the model author has not set a + cover image. + - The input file of the example prediction, if the model author has not set a + cover image and the example prediction has no output file. + - A generic fallback image. + + The `default_example` object is a [prediction](#predictions.get) created with + this model. + + The `latest_version` object is the model's most recently pushed + [version](#models.versions.get). + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not model_owner: + raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") + if not model_name: + raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") + extra_headers = {"Accept": "*/*", **(extra_headers or {})} + return self._get( + f"/models/{model_owner}/{model_name}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=NoneType, + ) + class AsyncModelsResource(AsyncAPIResource): @cached_property @@ -651,115 +651,6 @@ async def create( cast_to=NoneType, ) - async def retrieve( - self, - model_name: str, - *, - model_owner: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: - """ - Example cURL request: - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/models/replicate/hello-world - ``` - - The response will be a model object in the following format: - - ```json - { - "url": "https://replicate.com/replicate/hello-world", - "owner": "replicate", - "name": "hello-world", - "description": "A tiny model that says hello", - "visibility": "public", - "github_url": "https://github.com/replicate/cog-examples", - "paper_url": null, - "license_url": null, - "run_count": 5681081, - "cover_image_url": "...", - "default_example": {...}, - "latest_version": {...}, - } - ``` - - The model object includes the - [input and output schema](https://replicate.com/docs/reference/openapi#model-schemas) - for the latest version of the model. - - Here's an example showing how to fetch the model with cURL and display its input - schema with [jq](https://stedolan.github.io/jq/): - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/models/replicate/hello-world \\ - | jq ".latest_version.openapi_schema.components.schemas.Input" - ``` - - This will return the following JSON object: - - ```json - { - "type": "object", - "title": "Input", - "required": ["text"], - "properties": { - "text": { - "type": "string", - "title": "Text", - "x-order": 0, - "description": "Text to prefix with 'hello '" - } - } - } - ``` - - The `cover_image_url` string is an HTTPS URL for an image file. This can be: - - - An image uploaded by the model author. - - The output file of the example prediction, if the model author has not set a - cover image. - - The input file of the example prediction, if the model author has not set a - cover image and the example prediction has no output file. - - A generic fallback image. - - The `default_example` object is a [prediction](#predictions.get) created with - this model. - - The `latest_version` object is the model's most recently pushed - [version](#models.versions.get). - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not model_owner: - raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") - if not model_name: - raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return await self._get( - f"/models/{model_owner}/{model_name}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=NoneType, - ) - def list( self, *, @@ -991,6 +882,115 @@ async def create_prediction( cast_to=Prediction, ) + async def get( + self, + model_name: str, + *, + model_owner: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> None: + """ + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/models/replicate/hello-world + ``` + + The response will be a model object in the following format: + + ```json + { + "url": "https://replicate.com/replicate/hello-world", + "owner": "replicate", + "name": "hello-world", + "description": "A tiny model that says hello", + "visibility": "public", + "github_url": "https://github.com/replicate/cog-examples", + "paper_url": null, + "license_url": null, + "run_count": 5681081, + "cover_image_url": "...", + "default_example": {...}, + "latest_version": {...}, + } + ``` + + The model object includes the + [input and output schema](https://replicate.com/docs/reference/openapi#model-schemas) + for the latest version of the model. + + Here's an example showing how to fetch the model with cURL and display its input + schema with [jq](https://stedolan.github.io/jq/): + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/models/replicate/hello-world \\ + | jq ".latest_version.openapi_schema.components.schemas.Input" + ``` + + This will return the following JSON object: + + ```json + { + "type": "object", + "title": "Input", + "required": ["text"], + "properties": { + "text": { + "type": "string", + "title": "Text", + "x-order": 0, + "description": "Text to prefix with 'hello '" + } + } + } + ``` + + The `cover_image_url` string is an HTTPS URL for an image file. This can be: + + - An image uploaded by the model author. + - The output file of the example prediction, if the model author has not set a + cover image. + - The input file of the example prediction, if the model author has not set a + cover image and the example prediction has no output file. + - A generic fallback image. + + The `default_example` object is a [prediction](#predictions.get) created with + this model. + + The `latest_version` object is the model's most recently pushed + [version](#models.versions.get). + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not model_owner: + raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") + if not model_name: + raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") + extra_headers = {"Accept": "*/*", **(extra_headers or {})} + return await self._get( + f"/models/{model_owner}/{model_name}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=NoneType, + ) + class ModelsResourceWithRawResponse: def __init__(self, models: ModelsResource) -> None: @@ -999,9 +999,6 @@ def __init__(self, models: ModelsResource) -> None: self.create = to_raw_response_wrapper( models.create, ) - self.retrieve = to_raw_response_wrapper( - models.retrieve, - ) self.list = to_raw_response_wrapper( models.list, ) @@ -1011,6 +1008,9 @@ def __init__(self, models: ModelsResource) -> None: self.create_prediction = to_raw_response_wrapper( models.create_prediction, ) + self.get = to_raw_response_wrapper( + models.get, + ) @cached_property def versions(self) -> VersionsResourceWithRawResponse: @@ -1024,9 +1024,6 @@ def __init__(self, models: AsyncModelsResource) -> None: self.create = async_to_raw_response_wrapper( models.create, ) - self.retrieve = async_to_raw_response_wrapper( - models.retrieve, - ) self.list = async_to_raw_response_wrapper( models.list, ) @@ -1036,6 +1033,9 @@ def __init__(self, models: AsyncModelsResource) -> None: self.create_prediction = async_to_raw_response_wrapper( models.create_prediction, ) + self.get = async_to_raw_response_wrapper( + models.get, + ) @cached_property def versions(self) -> AsyncVersionsResourceWithRawResponse: @@ -1049,9 +1049,6 @@ def __init__(self, models: ModelsResource) -> None: self.create = to_streamed_response_wrapper( models.create, ) - self.retrieve = to_streamed_response_wrapper( - models.retrieve, - ) self.list = to_streamed_response_wrapper( models.list, ) @@ -1061,6 +1058,9 @@ def __init__(self, models: ModelsResource) -> None: self.create_prediction = to_streamed_response_wrapper( models.create_prediction, ) + self.get = to_streamed_response_wrapper( + models.get, + ) @cached_property def versions(self) -> VersionsResourceWithStreamingResponse: @@ -1074,9 +1074,6 @@ def __init__(self, models: AsyncModelsResource) -> None: self.create = async_to_streamed_response_wrapper( models.create, ) - self.retrieve = async_to_streamed_response_wrapper( - models.retrieve, - ) self.list = async_to_streamed_response_wrapper( models.list, ) @@ -1086,6 +1083,9 @@ def __init__(self, models: AsyncModelsResource) -> None: self.create_prediction = async_to_streamed_response_wrapper( models.create_prediction, ) + self.get = async_to_streamed_response_wrapper( + models.get, + ) @cached_property def versions(self) -> AsyncVersionsResourceWithStreamingResponse: diff --git a/src/replicate/resources/models/versions.py b/src/replicate/resources/models/versions.py index 203bb5d..2065fe0 100644 --- a/src/replicate/resources/models/versions.py +++ b/src/replicate/resources/models/versions.py @@ -47,101 +47,6 @@ def with_streaming_response(self) -> VersionsResourceWithStreamingResponse: """ return VersionsResourceWithStreamingResponse(self) - def retrieve( - self, - version_id: str, - *, - model_owner: str, - model_name: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: - """ - Example cURL request: - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/models/replicate/hello-world/versions/5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa - ``` - - The response will be the version object: - - ```json - { - "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - "created_at": "2022-04-26T19:29:04.418669Z", - "cog_version": "0.3.0", - "openapi_schema": {...} - } - ``` - - Every model describes its inputs and outputs with - [OpenAPI Schema Objects](https://spec.openapis.org/oas/latest.html#schemaObject) - in the `openapi_schema` property. - - The `openapi_schema.components.schemas.Input` property for the - [replicate/hello-world](https://replicate.com/replicate/hello-world) model looks - like this: - - ```json - { - "type": "object", - "title": "Input", - "required": ["text"], - "properties": { - "text": { - "x-order": 0, - "type": "string", - "title": "Text", - "description": "Text to prefix with 'hello '" - } - } - } - ``` - - The `openapi_schema.components.schemas.Output` property for the - [replicate/hello-world](https://replicate.com/replicate/hello-world) model looks - like this: - - ```json - { - "type": "string", - "title": "Output" - } - ``` - - For more details, see the docs on - [Cog's supported input and output types](https://github.com/replicate/cog/blob/75b7802219e7cd4cee845e34c4c22139558615d4/docs/python.md#input-and-output-types) - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not model_owner: - raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") - if not model_name: - raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") - if not version_id: - raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return self._get( - f"/models/{model_owner}/{model_name}/versions/{version_id}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=NoneType, - ) - def list( self, model_name: str, @@ -416,28 +321,7 @@ def create_training( cast_to=VersionCreateTrainingResponse, ) - -class AsyncVersionsResource(AsyncAPIResource): - @cached_property - def with_raw_response(self) -> AsyncVersionsResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers - """ - return AsyncVersionsResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> AsyncVersionsResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. - - For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response - """ - return AsyncVersionsResourceWithStreamingResponse(self) - - async def retrieve( + def get( self, version_id: str, *, @@ -524,7 +408,7 @@ async def retrieve( if not version_id: raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}") extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return await self._get( + return self._get( f"/models/{model_owner}/{model_name}/versions/{version_id}", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -532,6 +416,27 @@ async def retrieve( cast_to=NoneType, ) + +class AsyncVersionsResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncVersionsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers + """ + return AsyncVersionsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncVersionsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response + """ + return AsyncVersionsResourceWithStreamingResponse(self) + async def list( self, model_name: str, @@ -806,14 +711,106 @@ async def create_training( cast_to=VersionCreateTrainingResponse, ) + async def get( + self, + version_id: str, + *, + model_owner: str, + model_name: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> None: + """ + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/models/replicate/hello-world/versions/5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa + ``` + + The response will be the version object: + + ```json + { + "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + "created_at": "2022-04-26T19:29:04.418669Z", + "cog_version": "0.3.0", + "openapi_schema": {...} + } + ``` + + Every model describes its inputs and outputs with + [OpenAPI Schema Objects](https://spec.openapis.org/oas/latest.html#schemaObject) + in the `openapi_schema` property. + + The `openapi_schema.components.schemas.Input` property for the + [replicate/hello-world](https://replicate.com/replicate/hello-world) model looks + like this: + + ```json + { + "type": "object", + "title": "Input", + "required": ["text"], + "properties": { + "text": { + "x-order": 0, + "type": "string", + "title": "Text", + "description": "Text to prefix with 'hello '" + } + } + } + ``` + + The `openapi_schema.components.schemas.Output` property for the + [replicate/hello-world](https://replicate.com/replicate/hello-world) model looks + like this: + + ```json + { + "type": "string", + "title": "Output" + } + ``` + + For more details, see the docs on + [Cog's supported input and output types](https://github.com/replicate/cog/blob/75b7802219e7cd4cee845e34c4c22139558615d4/docs/python.md#input-and-output-types) + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not model_owner: + raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") + if not model_name: + raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") + if not version_id: + raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}") + extra_headers = {"Accept": "*/*", **(extra_headers or {})} + return await self._get( + f"/models/{model_owner}/{model_name}/versions/{version_id}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=NoneType, + ) + class VersionsResourceWithRawResponse: def __init__(self, versions: VersionsResource) -> None: self._versions = versions - self.retrieve = to_raw_response_wrapper( - versions.retrieve, - ) self.list = to_raw_response_wrapper( versions.list, ) @@ -823,15 +820,15 @@ def __init__(self, versions: VersionsResource) -> None: self.create_training = to_raw_response_wrapper( versions.create_training, ) + self.get = to_raw_response_wrapper( + versions.get, + ) class AsyncVersionsResourceWithRawResponse: def __init__(self, versions: AsyncVersionsResource) -> None: self._versions = versions - self.retrieve = async_to_raw_response_wrapper( - versions.retrieve, - ) self.list = async_to_raw_response_wrapper( versions.list, ) @@ -841,15 +838,15 @@ def __init__(self, versions: AsyncVersionsResource) -> None: self.create_training = async_to_raw_response_wrapper( versions.create_training, ) + self.get = async_to_raw_response_wrapper( + versions.get, + ) class VersionsResourceWithStreamingResponse: def __init__(self, versions: VersionsResource) -> None: self._versions = versions - self.retrieve = to_streamed_response_wrapper( - versions.retrieve, - ) self.list = to_streamed_response_wrapper( versions.list, ) @@ -859,15 +856,15 @@ def __init__(self, versions: VersionsResource) -> None: self.create_training = to_streamed_response_wrapper( versions.create_training, ) + self.get = to_streamed_response_wrapper( + versions.get, + ) class AsyncVersionsResourceWithStreamingResponse: def __init__(self, versions: AsyncVersionsResource) -> None: self._versions = versions - self.retrieve = async_to_streamed_response_wrapper( - versions.retrieve, - ) self.list = async_to_streamed_response_wrapper( versions.list, ) @@ -877,3 +874,6 @@ def __init__(self, versions: AsyncVersionsResource) -> None: self.create_training = async_to_streamed_response_wrapper( versions.create_training, ) + self.get = async_to_streamed_response_wrapper( + versions.get, + ) diff --git a/src/replicate/resources/predictions.py b/src/replicate/resources/predictions.py index d6a1245..c03991d 100644 --- a/src/replicate/resources/predictions.py +++ b/src/replicate/resources/predictions.py @@ -189,108 +189,6 @@ def create( cast_to=Prediction, ) - def retrieve( - self, - prediction_id: str, - *, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Prediction: - """ - Get the current state of a prediction. - - Example cURL request: - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu - ``` - - The response will be the prediction object: - - ```json - { - "id": "gm3qorzdhgbfurvjtvhg6dckhu", - "model": "replicate/hello-world", - "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - "input": { - "text": "Alice" - }, - "logs": "", - "output": "hello Alice", - "error": null, - "status": "succeeded", - "created_at": "2023-09-08T16:19:34.765994Z", - "data_removed": false, - "started_at": "2023-09-08T16:19:34.779176Z", - "completed_at": "2023-09-08T16:19:34.791859Z", - "metrics": { - "predict_time": 0.012683 - }, - "urls": { - "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel", - "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu" - } - } - ``` - - `status` will be one of: - - - `starting`: the prediction is starting up. If this status lasts longer than a - few seconds, then it's typically because a new worker is being started to run - the prediction. - - `processing`: the `predict()` method of the model is currently running. - - `succeeded`: the prediction completed successfully. - - `failed`: the prediction encountered an error during processing. - - `canceled`: the prediction was canceled by its creator. - - In the case of success, `output` will be an object containing the output of the - model. Any files will be represented as HTTPS URLs. You'll need to pass the - `Authorization` header to request them. - - In the case of failure, `error` will contain the error encountered during the - prediction. - - Terminated predictions (with a status of `succeeded`, `failed`, or `canceled`) - will include a `metrics` object with a `predict_time` property showing the - amount of CPU or GPU time, in seconds, that the prediction used while running. - It won't include time waiting for the prediction to start. - - All input parameters, output values, and logs are automatically removed after an - hour, by default, for predictions created through the API. - - You must save a copy of any data or files in the output if you'd like to - continue using them. The `output` key will still be present, but it's value will - be `null` after the output has been removed. - - Output files are served by `replicate.delivery` and its subdomains. If you use - an allow list of external domains for your assets, add `replicate.delivery` and - `*.replicate.delivery` to it. - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not prediction_id: - raise ValueError(f"Expected a non-empty value for `prediction_id` but received {prediction_id!r}") - return self._get( - f"/predictions/{prediction_id}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Prediction, - ) - def list( self, *, @@ -440,6 +338,108 @@ def cancel( cast_to=NoneType, ) + def get( + self, + prediction_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Prediction: + """ + Get the current state of a prediction. + + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu + ``` + + The response will be the prediction object: + + ```json + { + "id": "gm3qorzdhgbfurvjtvhg6dckhu", + "model": "replicate/hello-world", + "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + "input": { + "text": "Alice" + }, + "logs": "", + "output": "hello Alice", + "error": null, + "status": "succeeded", + "created_at": "2023-09-08T16:19:34.765994Z", + "data_removed": false, + "started_at": "2023-09-08T16:19:34.779176Z", + "completed_at": "2023-09-08T16:19:34.791859Z", + "metrics": { + "predict_time": 0.012683 + }, + "urls": { + "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel", + "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu" + } + } + ``` + + `status` will be one of: + + - `starting`: the prediction is starting up. If this status lasts longer than a + few seconds, then it's typically because a new worker is being started to run + the prediction. + - `processing`: the `predict()` method of the model is currently running. + - `succeeded`: the prediction completed successfully. + - `failed`: the prediction encountered an error during processing. + - `canceled`: the prediction was canceled by its creator. + + In the case of success, `output` will be an object containing the output of the + model. Any files will be represented as HTTPS URLs. You'll need to pass the + `Authorization` header to request them. + + In the case of failure, `error` will contain the error encountered during the + prediction. + + Terminated predictions (with a status of `succeeded`, `failed`, or `canceled`) + will include a `metrics` object with a `predict_time` property showing the + amount of CPU or GPU time, in seconds, that the prediction used while running. + It won't include time waiting for the prediction to start. + + All input parameters, output values, and logs are automatically removed after an + hour, by default, for predictions created through the API. + + You must save a copy of any data or files in the output if you'd like to + continue using them. The `output` key will still be present, but it's value will + be `null` after the output has been removed. + + Output files are served by `replicate.delivery` and its subdomains. If you use + an allow list of external domains for your assets, add `replicate.delivery` and + `*.replicate.delivery` to it. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not prediction_id: + raise ValueError(f"Expected a non-empty value for `prediction_id` but received {prediction_id!r}") + return self._get( + f"/predictions/{prediction_id}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Prediction, + ) + class AsyncPredictionsResource(AsyncAPIResource): @cached_property @@ -600,108 +600,6 @@ async def create( cast_to=Prediction, ) - async def retrieve( - self, - prediction_id: str, - *, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Prediction: - """ - Get the current state of a prediction. - - Example cURL request: - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu - ``` - - The response will be the prediction object: - - ```json - { - "id": "gm3qorzdhgbfurvjtvhg6dckhu", - "model": "replicate/hello-world", - "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - "input": { - "text": "Alice" - }, - "logs": "", - "output": "hello Alice", - "error": null, - "status": "succeeded", - "created_at": "2023-09-08T16:19:34.765994Z", - "data_removed": false, - "started_at": "2023-09-08T16:19:34.779176Z", - "completed_at": "2023-09-08T16:19:34.791859Z", - "metrics": { - "predict_time": 0.012683 - }, - "urls": { - "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel", - "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu" - } - } - ``` - - `status` will be one of: - - - `starting`: the prediction is starting up. If this status lasts longer than a - few seconds, then it's typically because a new worker is being started to run - the prediction. - - `processing`: the `predict()` method of the model is currently running. - - `succeeded`: the prediction completed successfully. - - `failed`: the prediction encountered an error during processing. - - `canceled`: the prediction was canceled by its creator. - - In the case of success, `output` will be an object containing the output of the - model. Any files will be represented as HTTPS URLs. You'll need to pass the - `Authorization` header to request them. - - In the case of failure, `error` will contain the error encountered during the - prediction. - - Terminated predictions (with a status of `succeeded`, `failed`, or `canceled`) - will include a `metrics` object with a `predict_time` property showing the - amount of CPU or GPU time, in seconds, that the prediction used while running. - It won't include time waiting for the prediction to start. - - All input parameters, output values, and logs are automatically removed after an - hour, by default, for predictions created through the API. - - You must save a copy of any data or files in the output if you'd like to - continue using them. The `output` key will still be present, but it's value will - be `null` after the output has been removed. - - Output files are served by `replicate.delivery` and its subdomains. If you use - an allow list of external domains for your assets, add `replicate.delivery` and - `*.replicate.delivery` to it. - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not prediction_id: - raise ValueError(f"Expected a non-empty value for `prediction_id` but received {prediction_id!r}") - return await self._get( - f"/predictions/{prediction_id}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Prediction, - ) - def list( self, *, @@ -851,6 +749,108 @@ async def cancel( cast_to=NoneType, ) + async def get( + self, + prediction_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Prediction: + """ + Get the current state of a prediction. + + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu + ``` + + The response will be the prediction object: + + ```json + { + "id": "gm3qorzdhgbfurvjtvhg6dckhu", + "model": "replicate/hello-world", + "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + "input": { + "text": "Alice" + }, + "logs": "", + "output": "hello Alice", + "error": null, + "status": "succeeded", + "created_at": "2023-09-08T16:19:34.765994Z", + "data_removed": false, + "started_at": "2023-09-08T16:19:34.779176Z", + "completed_at": "2023-09-08T16:19:34.791859Z", + "metrics": { + "predict_time": 0.012683 + }, + "urls": { + "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel", + "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu" + } + } + ``` + + `status` will be one of: + + - `starting`: the prediction is starting up. If this status lasts longer than a + few seconds, then it's typically because a new worker is being started to run + the prediction. + - `processing`: the `predict()` method of the model is currently running. + - `succeeded`: the prediction completed successfully. + - `failed`: the prediction encountered an error during processing. + - `canceled`: the prediction was canceled by its creator. + + In the case of success, `output` will be an object containing the output of the + model. Any files will be represented as HTTPS URLs. You'll need to pass the + `Authorization` header to request them. + + In the case of failure, `error` will contain the error encountered during the + prediction. + + Terminated predictions (with a status of `succeeded`, `failed`, or `canceled`) + will include a `metrics` object with a `predict_time` property showing the + amount of CPU or GPU time, in seconds, that the prediction used while running. + It won't include time waiting for the prediction to start. + + All input parameters, output values, and logs are automatically removed after an + hour, by default, for predictions created through the API. + + You must save a copy of any data or files in the output if you'd like to + continue using them. The `output` key will still be present, but it's value will + be `null` after the output has been removed. + + Output files are served by `replicate.delivery` and its subdomains. If you use + an allow list of external domains for your assets, add `replicate.delivery` and + `*.replicate.delivery` to it. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not prediction_id: + raise ValueError(f"Expected a non-empty value for `prediction_id` but received {prediction_id!r}") + return await self._get( + f"/predictions/{prediction_id}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Prediction, + ) + class PredictionsResourceWithRawResponse: def __init__(self, predictions: PredictionsResource) -> None: @@ -859,15 +859,15 @@ def __init__(self, predictions: PredictionsResource) -> None: self.create = to_raw_response_wrapper( predictions.create, ) - self.retrieve = to_raw_response_wrapper( - predictions.retrieve, - ) self.list = to_raw_response_wrapper( predictions.list, ) self.cancel = to_raw_response_wrapper( predictions.cancel, ) + self.get = to_raw_response_wrapper( + predictions.get, + ) class AsyncPredictionsResourceWithRawResponse: @@ -877,15 +877,15 @@ def __init__(self, predictions: AsyncPredictionsResource) -> None: self.create = async_to_raw_response_wrapper( predictions.create, ) - self.retrieve = async_to_raw_response_wrapper( - predictions.retrieve, - ) self.list = async_to_raw_response_wrapper( predictions.list, ) self.cancel = async_to_raw_response_wrapper( predictions.cancel, ) + self.get = async_to_raw_response_wrapper( + predictions.get, + ) class PredictionsResourceWithStreamingResponse: @@ -895,15 +895,15 @@ def __init__(self, predictions: PredictionsResource) -> None: self.create = to_streamed_response_wrapper( predictions.create, ) - self.retrieve = to_streamed_response_wrapper( - predictions.retrieve, - ) self.list = to_streamed_response_wrapper( predictions.list, ) self.cancel = to_streamed_response_wrapper( predictions.cancel, ) + self.get = to_streamed_response_wrapper( + predictions.get, + ) class AsyncPredictionsResourceWithStreamingResponse: @@ -913,12 +913,12 @@ def __init__(self, predictions: AsyncPredictionsResource) -> None: self.create = async_to_streamed_response_wrapper( predictions.create, ) - self.retrieve = async_to_streamed_response_wrapper( - predictions.retrieve, - ) self.list = async_to_streamed_response_wrapper( predictions.list, ) self.cancel = async_to_streamed_response_wrapper( predictions.cancel, ) + self.get = async_to_streamed_response_wrapper( + predictions.get, + ) diff --git a/src/replicate/resources/trainings.py b/src/replicate/resources/trainings.py index 46c5872..5a357af 100644 --- a/src/replicate/resources/trainings.py +++ b/src/replicate/resources/trainings.py @@ -15,9 +15,9 @@ ) from ..pagination import SyncCursorURLPage, AsyncCursorURLPage from .._base_client import AsyncPaginator, make_request_options +from ..types.training_get_response import TrainingGetResponse from ..types.training_list_response import TrainingListResponse from ..types.training_cancel_response import TrainingCancelResponse -from ..types.training_retrieve_response import TrainingRetrieveResponse __all__ = ["TrainingsResource", "AsyncTrainingsResource"] @@ -42,99 +42,6 @@ def with_streaming_response(self) -> TrainingsResourceWithStreamingResponse: """ return TrainingsResourceWithStreamingResponse(self) - def retrieve( - self, - training_id: str, - *, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> TrainingRetrieveResponse: - """ - Get the current state of a training. - - Example cURL request: - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga - ``` - - The response will be the training object: - - ```json - { - "completed_at": "2023-09-08T16:41:19.826523Z", - "created_at": "2023-09-08T16:32:57.018467Z", - "error": null, - "id": "zz4ibbonubfz7carwiefibzgga", - "input": { - "input_images": "https://example.com/my-input-images.zip" - }, - "logs": "...", - "metrics": { - "predict_time": 502.713876 - }, - "output": { - "version": "...", - "weights": "..." - }, - "started_at": "2023-09-08T16:32:57.112647Z", - "status": "succeeded", - "urls": { - "get": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga", - "cancel": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel" - }, - "model": "stability-ai/sdxl", - "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf" - } - ``` - - `status` will be one of: - - - `starting`: the training is starting up. If this status lasts longer than a - few seconds, then it's typically because a new worker is being started to run - the training. - - `processing`: the `train()` method of the model is currently running. - - `succeeded`: the training completed successfully. - - `failed`: the training encountered an error during processing. - - `canceled`: the training was canceled by its creator. - - In the case of success, `output` will be an object containing the output of the - model. Any files will be represented as HTTPS URLs. You'll need to pass the - `Authorization` header to request them. - - In the case of failure, `error` will contain the error encountered during the - training. - - Terminated trainings (with a status of `succeeded`, `failed`, or `canceled`) - will include a `metrics` object with a `predict_time` property showing the - amount of CPU or GPU time, in seconds, that the training used while running. It - won't include time waiting for the training to start. - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not training_id: - raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}") - return self._get( - f"/trainings/{training_id}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=TrainingRetrieveResponse, - ) - def list( self, *, @@ -252,28 +159,7 @@ def cancel( cast_to=TrainingCancelResponse, ) - -class AsyncTrainingsResource(AsyncAPIResource): - @cached_property - def with_raw_response(self) -> AsyncTrainingsResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers - """ - return AsyncTrainingsResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> AsyncTrainingsResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. - - For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response - """ - return AsyncTrainingsResourceWithStreamingResponse(self) - - async def retrieve( + def get( self, training_id: str, *, @@ -283,7 +169,7 @@ async def retrieve( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> TrainingRetrieveResponse: + ) -> TrainingGetResponse: """ Get the current state of a training. @@ -358,14 +244,35 @@ async def retrieve( """ if not training_id: raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}") - return await self._get( + return self._get( f"/trainings/{training_id}", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=TrainingRetrieveResponse, + cast_to=TrainingGetResponse, ) + +class AsyncTrainingsResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncTrainingsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers + """ + return AsyncTrainingsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncTrainingsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response + """ + return AsyncTrainingsResourceWithStreamingResponse(self) + def list( self, *, @@ -483,62 +390,155 @@ async def cancel( cast_to=TrainingCancelResponse, ) + async def get( + self, + training_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> TrainingGetResponse: + """ + Get the current state of a training. + + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga + ``` + + The response will be the training object: + + ```json + { + "completed_at": "2023-09-08T16:41:19.826523Z", + "created_at": "2023-09-08T16:32:57.018467Z", + "error": null, + "id": "zz4ibbonubfz7carwiefibzgga", + "input": { + "input_images": "https://example.com/my-input-images.zip" + }, + "logs": "...", + "metrics": { + "predict_time": 502.713876 + }, + "output": { + "version": "...", + "weights": "..." + }, + "started_at": "2023-09-08T16:32:57.112647Z", + "status": "succeeded", + "urls": { + "get": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga", + "cancel": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel" + }, + "model": "stability-ai/sdxl", + "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf" + } + ``` + + `status` will be one of: + + - `starting`: the training is starting up. If this status lasts longer than a + few seconds, then it's typically because a new worker is being started to run + the training. + - `processing`: the `train()` method of the model is currently running. + - `succeeded`: the training completed successfully. + - `failed`: the training encountered an error during processing. + - `canceled`: the training was canceled by its creator. + + In the case of success, `output` will be an object containing the output of the + model. Any files will be represented as HTTPS URLs. You'll need to pass the + `Authorization` header to request them. + + In the case of failure, `error` will contain the error encountered during the + training. + + Terminated trainings (with a status of `succeeded`, `failed`, or `canceled`) + will include a `metrics` object with a `predict_time` property showing the + amount of CPU or GPU time, in seconds, that the training used while running. It + won't include time waiting for the training to start. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not training_id: + raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}") + return await self._get( + f"/trainings/{training_id}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=TrainingGetResponse, + ) + class TrainingsResourceWithRawResponse: def __init__(self, trainings: TrainingsResource) -> None: self._trainings = trainings - self.retrieve = to_raw_response_wrapper( - trainings.retrieve, - ) self.list = to_raw_response_wrapper( trainings.list, ) self.cancel = to_raw_response_wrapper( trainings.cancel, ) + self.get = to_raw_response_wrapper( + trainings.get, + ) class AsyncTrainingsResourceWithRawResponse: def __init__(self, trainings: AsyncTrainingsResource) -> None: self._trainings = trainings - self.retrieve = async_to_raw_response_wrapper( - trainings.retrieve, - ) self.list = async_to_raw_response_wrapper( trainings.list, ) self.cancel = async_to_raw_response_wrapper( trainings.cancel, ) + self.get = async_to_raw_response_wrapper( + trainings.get, + ) class TrainingsResourceWithStreamingResponse: def __init__(self, trainings: TrainingsResource) -> None: self._trainings = trainings - self.retrieve = to_streamed_response_wrapper( - trainings.retrieve, - ) self.list = to_streamed_response_wrapper( trainings.list, ) self.cancel = to_streamed_response_wrapper( trainings.cancel, ) + self.get = to_streamed_response_wrapper( + trainings.get, + ) class AsyncTrainingsResourceWithStreamingResponse: def __init__(self, trainings: AsyncTrainingsResource) -> None: self._trainings = trainings - self.retrieve = async_to_streamed_response_wrapper( - trainings.retrieve, - ) self.list = async_to_streamed_response_wrapper( trainings.list, ) self.cancel = async_to_streamed_response_wrapper( trainings.cancel, ) + self.get = async_to_streamed_response_wrapper( + trainings.get, + ) diff --git a/src/replicate/types/__init__.py b/src/replicate/types/__init__.py index 7d2d8f4..fa8ee5d 100644 --- a/src/replicate/types/__init__.py +++ b/src/replicate/types/__init__.py @@ -7,9 +7,11 @@ from .model_create_params import ModelCreateParams as ModelCreateParams from .model_list_response import ModelListResponse as ModelListResponse from .account_list_response import AccountListResponse as AccountListResponse +from .training_get_response import TrainingGetResponse as TrainingGetResponse from .hardware_list_response import HardwareListResponse as HardwareListResponse from .prediction_list_params import PredictionListParams as PredictionListParams from .training_list_response import TrainingListResponse as TrainingListResponse +from .deployment_get_response import DeploymentGetResponse as DeploymentGetResponse from .deployment_create_params import DeploymentCreateParams as DeploymentCreateParams from .deployment_list_response import DeploymentListResponse as DeploymentListResponse from .deployment_update_params import DeploymentUpdateParams as DeploymentUpdateParams @@ -17,6 +19,4 @@ from .training_cancel_response import TrainingCancelResponse as TrainingCancelResponse from .deployment_create_response import DeploymentCreateResponse as DeploymentCreateResponse from .deployment_update_response import DeploymentUpdateResponse as DeploymentUpdateResponse -from .training_retrieve_response import TrainingRetrieveResponse as TrainingRetrieveResponse -from .deployment_retrieve_response import DeploymentRetrieveResponse as DeploymentRetrieveResponse from .model_create_prediction_params import ModelCreatePredictionParams as ModelCreatePredictionParams diff --git a/src/replicate/types/deployment_retrieve_response.py b/src/replicate/types/deployment_get_response.py similarity index 91% rename from src/replicate/types/deployment_retrieve_response.py rename to src/replicate/types/deployment_get_response.py index 2ecc13c..a7281ff 100644 --- a/src/replicate/types/deployment_retrieve_response.py +++ b/src/replicate/types/deployment_get_response.py @@ -6,7 +6,7 @@ from .._models import BaseModel -__all__ = ["DeploymentRetrieveResponse", "CurrentRelease", "CurrentReleaseConfiguration", "CurrentReleaseCreatedBy"] +__all__ = ["DeploymentGetResponse", "CurrentRelease", "CurrentReleaseConfiguration", "CurrentReleaseCreatedBy"] class CurrentReleaseConfiguration(BaseModel): @@ -55,7 +55,7 @@ class CurrentRelease(BaseModel): """The ID of the model version used in the release.""" -class DeploymentRetrieveResponse(BaseModel): +class DeploymentGetResponse(BaseModel): current_release: Optional[CurrentRelease] = None name: Optional[str] = None diff --git a/src/replicate/types/training_retrieve_response.py b/src/replicate/types/training_get_response.py similarity index 94% rename from src/replicate/types/training_retrieve_response.py rename to src/replicate/types/training_get_response.py index 55b1173..4169da7 100644 --- a/src/replicate/types/training_retrieve_response.py +++ b/src/replicate/types/training_get_response.py @@ -6,7 +6,7 @@ from .._models import BaseModel -__all__ = ["TrainingRetrieveResponse", "Metrics", "Output", "URLs"] +__all__ = ["TrainingGetResponse", "Metrics", "Output", "URLs"] class Metrics(BaseModel): @@ -30,7 +30,7 @@ class URLs(BaseModel): """URL to get the training details""" -class TrainingRetrieveResponse(BaseModel): +class TrainingGetResponse(BaseModel): id: Optional[str] = None """The unique ID of the training""" diff --git a/tests/api_resources/models/test_versions.py b/tests/api_resources/models/test_versions.py index 7e6d583..d1fb7a8 100644 --- a/tests/api_resources/models/test_versions.py +++ b/tests/api_resources/models/test_versions.py @@ -17,70 +17,6 @@ class TestVersions: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - @pytest.mark.skip() - @parametrize - def test_method_retrieve(self, client: ReplicateClient) -> None: - version = client.models.versions.retrieve( - version_id="version_id", - model_owner="model_owner", - model_name="model_name", - ) - assert version is None - - @pytest.mark.skip() - @parametrize - def test_raw_response_retrieve(self, client: ReplicateClient) -> None: - response = client.models.versions.with_raw_response.retrieve( - version_id="version_id", - model_owner="model_owner", - model_name="model_name", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - version = response.parse() - assert version is None - - @pytest.mark.skip() - @parametrize - def test_streaming_response_retrieve(self, client: ReplicateClient) -> None: - with client.models.versions.with_streaming_response.retrieve( - version_id="version_id", - model_owner="model_owner", - model_name="model_name", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - version = response.parse() - assert version is None - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip() - @parametrize - def test_path_params_retrieve(self, client: ReplicateClient) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): - client.models.versions.with_raw_response.retrieve( - version_id="version_id", - model_owner="", - model_name="model_name", - ) - - with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"): - client.models.versions.with_raw_response.retrieve( - version_id="version_id", - model_owner="model_owner", - model_name="", - ) - - with pytest.raises(ValueError, match=r"Expected a non-empty value for `version_id` but received ''"): - client.models.versions.with_raw_response.retrieve( - version_id="", - model_owner="model_owner", - model_name="model_name", - ) - @pytest.mark.skip() @parametrize def test_method_list(self, client: ReplicateClient) -> None: @@ -287,14 +223,10 @@ def test_path_params_create_training(self, client: ReplicateClient) -> None: input={}, ) - -class TestAsyncVersions: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) - @pytest.mark.skip() @parametrize - async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None: - version = await async_client.models.versions.retrieve( + def test_method_get(self, client: ReplicateClient) -> None: + version = client.models.versions.get( version_id="version_id", model_owner="model_owner", model_name="model_name", @@ -303,8 +235,8 @@ async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None @pytest.mark.skip() @parametrize - async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - response = await async_client.models.versions.with_raw_response.retrieve( + def test_raw_response_get(self, client: ReplicateClient) -> None: + response = client.models.versions.with_raw_response.get( version_id="version_id", model_owner="model_owner", model_name="model_name", @@ -312,13 +244,13 @@ async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) - assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - version = await response.parse() + version = response.parse() assert version is None @pytest.mark.skip() @parametrize - async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - async with async_client.models.versions.with_streaming_response.retrieve( + def test_streaming_response_get(self, client: ReplicateClient) -> None: + with client.models.versions.with_streaming_response.get( version_id="version_id", model_owner="model_owner", model_name="model_name", @@ -326,35 +258,39 @@ async def test_streaming_response_retrieve(self, async_client: AsyncReplicateCli assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" - version = await response.parse() + version = response.parse() assert version is None assert cast(Any, response.is_closed) is True @pytest.mark.skip() @parametrize - async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None: + def test_path_params_get(self, client: ReplicateClient) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): - await async_client.models.versions.with_raw_response.retrieve( + client.models.versions.with_raw_response.get( version_id="version_id", model_owner="", model_name="model_name", ) with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"): - await async_client.models.versions.with_raw_response.retrieve( + client.models.versions.with_raw_response.get( version_id="version_id", model_owner="model_owner", model_name="", ) with pytest.raises(ValueError, match=r"Expected a non-empty value for `version_id` but received ''"): - await async_client.models.versions.with_raw_response.retrieve( + client.models.versions.with_raw_response.get( version_id="", model_owner="model_owner", model_name="model_name", ) + +class TestAsyncVersions: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + @pytest.mark.skip() @parametrize async def test_method_list(self, async_client: AsyncReplicateClient) -> None: @@ -560,3 +496,67 @@ async def test_path_params_create_training(self, async_client: AsyncReplicateCli destination="destination", input={}, ) + + @pytest.mark.skip() + @parametrize + async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + version = await async_client.models.versions.get( + version_id="version_id", + model_owner="model_owner", + model_name="model_name", + ) + assert version is None + + @pytest.mark.skip() + @parametrize + async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + response = await async_client.models.versions.with_raw_response.get( + version_id="version_id", + model_owner="model_owner", + model_name="model_name", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + version = await response.parse() + assert version is None + + @pytest.mark.skip() + @parametrize + async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async with async_client.models.versions.with_streaming_response.get( + version_id="version_id", + model_owner="model_owner", + model_name="model_name", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + version = await response.parse() + assert version is None + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): + await async_client.models.versions.with_raw_response.get( + version_id="version_id", + model_owner="", + model_name="model_name", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"): + await async_client.models.versions.with_raw_response.get( + version_id="version_id", + model_owner="model_owner", + model_name="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `version_id` but received ''"): + await async_client.models.versions.with_raw_response.get( + version_id="", + model_owner="model_owner", + model_name="model_name", + ) diff --git a/tests/api_resources/test_deployments.py b/tests/api_resources/test_deployments.py index b834241..3a01dcb 100644 --- a/tests/api_resources/test_deployments.py +++ b/tests/api_resources/test_deployments.py @@ -10,10 +10,10 @@ from replicate import ReplicateClient, AsyncReplicateClient from tests.utils import assert_matches_type from replicate.types import ( + DeploymentGetResponse, DeploymentListResponse, DeploymentCreateResponse, DeploymentUpdateResponse, - DeploymentRetrieveResponse, ) from replicate.pagination import SyncCursorURLPage, AsyncCursorURLPage @@ -72,58 +72,6 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() - @parametrize - def test_method_retrieve(self, client: ReplicateClient) -> None: - deployment = client.deployments.retrieve( - deployment_name="deployment_name", - deployment_owner="deployment_owner", - ) - assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"]) - - @pytest.mark.skip() - @parametrize - def test_raw_response_retrieve(self, client: ReplicateClient) -> None: - response = client.deployments.with_raw_response.retrieve( - deployment_name="deployment_name", - deployment_owner="deployment_owner", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - deployment = response.parse() - assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"]) - - @pytest.mark.skip() - @parametrize - def test_streaming_response_retrieve(self, client: ReplicateClient) -> None: - with client.deployments.with_streaming_response.retrieve( - deployment_name="deployment_name", - deployment_owner="deployment_owner", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - deployment = response.parse() - assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip() - @parametrize - def test_path_params_retrieve(self, client: ReplicateClient) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"): - client.deployments.with_raw_response.retrieve( - deployment_name="deployment_name", - deployment_owner="", - ) - - with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_name` but received ''"): - client.deployments.with_raw_response.retrieve( - deployment_name="", - deployment_owner="deployment_owner", - ) - @pytest.mark.skip() @parametrize def test_method_update(self, client: ReplicateClient) -> None: @@ -269,6 +217,58 @@ def test_path_params_delete(self, client: ReplicateClient) -> None: deployment_owner="deployment_owner", ) + @pytest.mark.skip() + @parametrize + def test_method_get(self, client: ReplicateClient) -> None: + deployment = client.deployments.get( + deployment_name="deployment_name", + deployment_owner="deployment_owner", + ) + assert_matches_type(DeploymentGetResponse, deployment, path=["response"]) + + @pytest.mark.skip() + @parametrize + def test_raw_response_get(self, client: ReplicateClient) -> None: + response = client.deployments.with_raw_response.get( + deployment_name="deployment_name", + deployment_owner="deployment_owner", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + deployment = response.parse() + assert_matches_type(DeploymentGetResponse, deployment, path=["response"]) + + @pytest.mark.skip() + @parametrize + def test_streaming_response_get(self, client: ReplicateClient) -> None: + with client.deployments.with_streaming_response.get( + deployment_name="deployment_name", + deployment_owner="deployment_owner", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + deployment = response.parse() + assert_matches_type(DeploymentGetResponse, deployment, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + def test_path_params_get(self, client: ReplicateClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"): + client.deployments.with_raw_response.get( + deployment_name="deployment_name", + deployment_owner="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_name` but received ''"): + client.deployments.with_raw_response.get( + deployment_name="", + deployment_owner="deployment_owner", + ) + @pytest.mark.skip() @parametrize def test_method_list_em_all(self, client: ReplicateClient) -> None: @@ -350,58 +350,6 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien assert cast(Any, response.is_closed) is True - @pytest.mark.skip() - @parametrize - async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None: - deployment = await async_client.deployments.retrieve( - deployment_name="deployment_name", - deployment_owner="deployment_owner", - ) - assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"]) - - @pytest.mark.skip() - @parametrize - async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - response = await async_client.deployments.with_raw_response.retrieve( - deployment_name="deployment_name", - deployment_owner="deployment_owner", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - deployment = await response.parse() - assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"]) - - @pytest.mark.skip() - @parametrize - async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - async with async_client.deployments.with_streaming_response.retrieve( - deployment_name="deployment_name", - deployment_owner="deployment_owner", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - deployment = await response.parse() - assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip() - @parametrize - async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"): - await async_client.deployments.with_raw_response.retrieve( - deployment_name="deployment_name", - deployment_owner="", - ) - - with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_name` but received ''"): - await async_client.deployments.with_raw_response.retrieve( - deployment_name="", - deployment_owner="deployment_owner", - ) - @pytest.mark.skip() @parametrize async def test_method_update(self, async_client: AsyncReplicateClient) -> None: @@ -547,6 +495,58 @@ async def test_path_params_delete(self, async_client: AsyncReplicateClient) -> N deployment_owner="deployment_owner", ) + @pytest.mark.skip() + @parametrize + async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + deployment = await async_client.deployments.get( + deployment_name="deployment_name", + deployment_owner="deployment_owner", + ) + assert_matches_type(DeploymentGetResponse, deployment, path=["response"]) + + @pytest.mark.skip() + @parametrize + async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + response = await async_client.deployments.with_raw_response.get( + deployment_name="deployment_name", + deployment_owner="deployment_owner", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + deployment = await response.parse() + assert_matches_type(DeploymentGetResponse, deployment, path=["response"]) + + @pytest.mark.skip() + @parametrize + async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async with async_client.deployments.with_streaming_response.get( + deployment_name="deployment_name", + deployment_owner="deployment_owner", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + deployment = await response.parse() + assert_matches_type(DeploymentGetResponse, deployment, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"): + await async_client.deployments.with_raw_response.get( + deployment_name="deployment_name", + deployment_owner="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_name` but received ''"): + await async_client.deployments.with_raw_response.get( + deployment_name="", + deployment_owner="deployment_owner", + ) + @pytest.mark.skip() @parametrize async def test_method_list_em_all(self, async_client: AsyncReplicateClient) -> None: diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py index f56f00c..a2a9f52 100644 --- a/tests/api_resources/test_models.py +++ b/tests/api_resources/test_models.py @@ -77,58 +77,6 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() - @parametrize - def test_method_retrieve(self, client: ReplicateClient) -> None: - model = client.models.retrieve( - model_name="model_name", - model_owner="model_owner", - ) - assert model is None - - @pytest.mark.skip() - @parametrize - def test_raw_response_retrieve(self, client: ReplicateClient) -> None: - response = client.models.with_raw_response.retrieve( - model_name="model_name", - model_owner="model_owner", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - model = response.parse() - assert model is None - - @pytest.mark.skip() - @parametrize - def test_streaming_response_retrieve(self, client: ReplicateClient) -> None: - with client.models.with_streaming_response.retrieve( - model_name="model_name", - model_owner="model_owner", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - model = response.parse() - assert model is None - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip() - @parametrize - def test_path_params_retrieve(self, client: ReplicateClient) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): - client.models.with_raw_response.retrieve( - model_name="model_name", - model_owner="", - ) - - with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"): - client.models.with_raw_response.retrieve( - model_name="", - model_owner="model_owner", - ) - @pytest.mark.skip() @parametrize def test_method_list(self, client: ReplicateClient) -> None: @@ -280,6 +228,58 @@ def test_path_params_create_prediction(self, client: ReplicateClient) -> None: input={}, ) + @pytest.mark.skip() + @parametrize + def test_method_get(self, client: ReplicateClient) -> None: + model = client.models.get( + model_name="model_name", + model_owner="model_owner", + ) + assert model is None + + @pytest.mark.skip() + @parametrize + def test_raw_response_get(self, client: ReplicateClient) -> None: + response = client.models.with_raw_response.get( + model_name="model_name", + model_owner="model_owner", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + model = response.parse() + assert model is None + + @pytest.mark.skip() + @parametrize + def test_streaming_response_get(self, client: ReplicateClient) -> None: + with client.models.with_streaming_response.get( + model_name="model_name", + model_owner="model_owner", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + model = response.parse() + assert model is None + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + def test_path_params_get(self, client: ReplicateClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): + client.models.with_raw_response.get( + model_name="model_name", + model_owner="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"): + client.models.with_raw_response.get( + model_name="", + model_owner="model_owner", + ) + class TestAsyncModels: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @@ -343,58 +343,6 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien assert cast(Any, response.is_closed) is True - @pytest.mark.skip() - @parametrize - async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None: - model = await async_client.models.retrieve( - model_name="model_name", - model_owner="model_owner", - ) - assert model is None - - @pytest.mark.skip() - @parametrize - async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - response = await async_client.models.with_raw_response.retrieve( - model_name="model_name", - model_owner="model_owner", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - model = await response.parse() - assert model is None - - @pytest.mark.skip() - @parametrize - async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - async with async_client.models.with_streaming_response.retrieve( - model_name="model_name", - model_owner="model_owner", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - model = await response.parse() - assert model is None - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip() - @parametrize - async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): - await async_client.models.with_raw_response.retrieve( - model_name="model_name", - model_owner="", - ) - - with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"): - await async_client.models.with_raw_response.retrieve( - model_name="", - model_owner="model_owner", - ) - @pytest.mark.skip() @parametrize async def test_method_list(self, async_client: AsyncReplicateClient) -> None: @@ -545,3 +493,55 @@ async def test_path_params_create_prediction(self, async_client: AsyncReplicateC model_owner="model_owner", input={}, ) + + @pytest.mark.skip() + @parametrize + async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + model = await async_client.models.get( + model_name="model_name", + model_owner="model_owner", + ) + assert model is None + + @pytest.mark.skip() + @parametrize + async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + response = await async_client.models.with_raw_response.get( + model_name="model_name", + model_owner="model_owner", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + model = await response.parse() + assert model is None + + @pytest.mark.skip() + @parametrize + async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async with async_client.models.with_streaming_response.get( + model_name="model_name", + model_owner="model_owner", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + model = await response.parse() + assert model is None + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): + await async_client.models.with_raw_response.get( + model_name="model_name", + model_owner="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"): + await async_client.models.with_raw_response.get( + model_name="", + model_owner="model_owner", + ) diff --git a/tests/api_resources/test_predictions.py b/tests/api_resources/test_predictions.py index 8fc4142..51bd80d 100644 --- a/tests/api_resources/test_predictions.py +++ b/tests/api_resources/test_predictions.py @@ -69,48 +69,6 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() - @parametrize - def test_method_retrieve(self, client: ReplicateClient) -> None: - prediction = client.predictions.retrieve( - "prediction_id", - ) - assert_matches_type(Prediction, prediction, path=["response"]) - - @pytest.mark.skip() - @parametrize - def test_raw_response_retrieve(self, client: ReplicateClient) -> None: - response = client.predictions.with_raw_response.retrieve( - "prediction_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - prediction = response.parse() - assert_matches_type(Prediction, prediction, path=["response"]) - - @pytest.mark.skip() - @parametrize - def test_streaming_response_retrieve(self, client: ReplicateClient) -> None: - with client.predictions.with_streaming_response.retrieve( - "prediction_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - prediction = response.parse() - assert_matches_type(Prediction, prediction, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip() - @parametrize - def test_path_params_retrieve(self, client: ReplicateClient) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"): - client.predictions.with_raw_response.retrieve( - "", - ) - @pytest.mark.skip() @parametrize def test_method_list(self, client: ReplicateClient) -> None: @@ -190,6 +148,48 @@ def test_path_params_cancel(self, client: ReplicateClient) -> None: "", ) + @pytest.mark.skip() + @parametrize + def test_method_get(self, client: ReplicateClient) -> None: + prediction = client.predictions.get( + "prediction_id", + ) + assert_matches_type(Prediction, prediction, path=["response"]) + + @pytest.mark.skip() + @parametrize + def test_raw_response_get(self, client: ReplicateClient) -> None: + response = client.predictions.with_raw_response.get( + "prediction_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + prediction = response.parse() + assert_matches_type(Prediction, prediction, path=["response"]) + + @pytest.mark.skip() + @parametrize + def test_streaming_response_get(self, client: ReplicateClient) -> None: + with client.predictions.with_streaming_response.get( + "prediction_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + prediction = response.parse() + assert_matches_type(Prediction, prediction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + def test_path_params_get(self, client: ReplicateClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"): + client.predictions.with_raw_response.get( + "", + ) + class TestAsyncPredictions: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @@ -244,48 +244,6 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien assert cast(Any, response.is_closed) is True - @pytest.mark.skip() - @parametrize - async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None: - prediction = await async_client.predictions.retrieve( - "prediction_id", - ) - assert_matches_type(Prediction, prediction, path=["response"]) - - @pytest.mark.skip() - @parametrize - async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - response = await async_client.predictions.with_raw_response.retrieve( - "prediction_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - prediction = await response.parse() - assert_matches_type(Prediction, prediction, path=["response"]) - - @pytest.mark.skip() - @parametrize - async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - async with async_client.predictions.with_streaming_response.retrieve( - "prediction_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - prediction = await response.parse() - assert_matches_type(Prediction, prediction, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip() - @parametrize - async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"): - await async_client.predictions.with_raw_response.retrieve( - "", - ) - @pytest.mark.skip() @parametrize async def test_method_list(self, async_client: AsyncReplicateClient) -> None: @@ -364,3 +322,45 @@ async def test_path_params_cancel(self, async_client: AsyncReplicateClient) -> N await async_client.predictions.with_raw_response.cancel( "", ) + + @pytest.mark.skip() + @parametrize + async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + prediction = await async_client.predictions.get( + "prediction_id", + ) + assert_matches_type(Prediction, prediction, path=["response"]) + + @pytest.mark.skip() + @parametrize + async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + response = await async_client.predictions.with_raw_response.get( + "prediction_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + prediction = await response.parse() + assert_matches_type(Prediction, prediction, path=["response"]) + + @pytest.mark.skip() + @parametrize + async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async with async_client.predictions.with_streaming_response.get( + "prediction_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + prediction = await response.parse() + assert_matches_type(Prediction, prediction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"): + await async_client.predictions.with_raw_response.get( + "", + ) diff --git a/tests/api_resources/test_trainings.py b/tests/api_resources/test_trainings.py index 68b1a85..f2dadb1 100644 --- a/tests/api_resources/test_trainings.py +++ b/tests/api_resources/test_trainings.py @@ -9,7 +9,7 @@ from replicate import ReplicateClient, AsyncReplicateClient from tests.utils import assert_matches_type -from replicate.types import TrainingListResponse, TrainingCancelResponse, TrainingRetrieveResponse +from replicate.types import TrainingGetResponse, TrainingListResponse, TrainingCancelResponse from replicate.pagination import SyncCursorURLPage, AsyncCursorURLPage base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -18,48 +18,6 @@ class TestTrainings: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - @pytest.mark.skip() - @parametrize - def test_method_retrieve(self, client: ReplicateClient) -> None: - training = client.trainings.retrieve( - "training_id", - ) - assert_matches_type(TrainingRetrieveResponse, training, path=["response"]) - - @pytest.mark.skip() - @parametrize - def test_raw_response_retrieve(self, client: ReplicateClient) -> None: - response = client.trainings.with_raw_response.retrieve( - "training_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - training = response.parse() - assert_matches_type(TrainingRetrieveResponse, training, path=["response"]) - - @pytest.mark.skip() - @parametrize - def test_streaming_response_retrieve(self, client: ReplicateClient) -> None: - with client.trainings.with_streaming_response.retrieve( - "training_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - training = response.parse() - assert_matches_type(TrainingRetrieveResponse, training, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip() - @parametrize - def test_path_params_retrieve(self, client: ReplicateClient) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `training_id` but received ''"): - client.trainings.with_raw_response.retrieve( - "", - ) - @pytest.mark.skip() @parametrize def test_method_list(self, client: ReplicateClient) -> None: @@ -130,52 +88,52 @@ def test_path_params_cancel(self, client: ReplicateClient) -> None: "", ) - -class TestAsyncTrainings: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) - @pytest.mark.skip() @parametrize - async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None: - training = await async_client.trainings.retrieve( + def test_method_get(self, client: ReplicateClient) -> None: + training = client.trainings.get( "training_id", ) - assert_matches_type(TrainingRetrieveResponse, training, path=["response"]) + assert_matches_type(TrainingGetResponse, training, path=["response"]) @pytest.mark.skip() @parametrize - async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - response = await async_client.trainings.with_raw_response.retrieve( + def test_raw_response_get(self, client: ReplicateClient) -> None: + response = client.trainings.with_raw_response.get( "training_id", ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - training = await response.parse() - assert_matches_type(TrainingRetrieveResponse, training, path=["response"]) + training = response.parse() + assert_matches_type(TrainingGetResponse, training, path=["response"]) @pytest.mark.skip() @parametrize - async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - async with async_client.trainings.with_streaming_response.retrieve( + def test_streaming_response_get(self, client: ReplicateClient) -> None: + with client.trainings.with_streaming_response.get( "training_id", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" - training = await response.parse() - assert_matches_type(TrainingRetrieveResponse, training, path=["response"]) + training = response.parse() + assert_matches_type(TrainingGetResponse, training, path=["response"]) assert cast(Any, response.is_closed) is True @pytest.mark.skip() @parametrize - async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None: + def test_path_params_get(self, client: ReplicateClient) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `training_id` but received ''"): - await async_client.trainings.with_raw_response.retrieve( + client.trainings.with_raw_response.get( "", ) + +class TestAsyncTrainings: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + @pytest.mark.skip() @parametrize async def test_method_list(self, async_client: AsyncReplicateClient) -> None: @@ -245,3 +203,45 @@ async def test_path_params_cancel(self, async_client: AsyncReplicateClient) -> N await async_client.trainings.with_raw_response.cancel( "", ) + + @pytest.mark.skip() + @parametrize + async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + training = await async_client.trainings.get( + "training_id", + ) + assert_matches_type(TrainingGetResponse, training, path=["response"]) + + @pytest.mark.skip() + @parametrize + async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + response = await async_client.trainings.with_raw_response.get( + "training_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + training = await response.parse() + assert_matches_type(TrainingGetResponse, training, path=["response"]) + + @pytest.mark.skip() + @parametrize + async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async with async_client.trainings.with_streaming_response.get( + "training_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + training = await response.parse() + assert_matches_type(TrainingGetResponse, training, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `training_id` but received ''"): + await async_client.trainings.with_raw_response.get( + "", + ) From fc34c6d4fc36a41441ab8417f85343e640b53b76 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Sat, 19 Apr 2025 03:40:50 +0000 Subject: [PATCH 06/12] chore(internal): update models test --- tests/test_models.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_models.py b/tests/test_models.py index 25f3d43..cf8173c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -492,12 +492,15 @@ class Model(BaseModel): resource_id: Optional[str] = None m = Model.construct() + assert m.resource_id is None assert "resource_id" not in m.model_fields_set m = Model.construct(resource_id=None) + assert m.resource_id is None assert "resource_id" in m.model_fields_set m = Model.construct(resource_id="foo") + assert m.resource_id == "foo" assert "resource_id" in m.model_fields_set From 1bad4d3d3676a323032f37f0195ff640fcce3458 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 23 Apr 2025 04:34:13 +0000 Subject: [PATCH 07/12] chore(ci): add timeout thresholds for CI jobs --- .github/workflows/ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 81f6dc2..04b083c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,6 +10,7 @@ on: jobs: lint: + timeout-minutes: 10 name: lint runs-on: ubuntu-latest steps: @@ -30,6 +31,7 @@ jobs: run: ./scripts/lint test: + timeout-minutes: 10 name: test runs-on: ubuntu-latest steps: From 4cdf515372a9e936c3a18afd24a444a778b1f7f5 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 23 Apr 2025 04:34:45 +0000 Subject: [PATCH 08/12] chore(internal): import reformatting --- src/replicate/_client.py | 5 +---- src/replicate/resources/deployments/deployments.py | 5 +---- src/replicate/resources/deployments/predictions.py | 6 +----- src/replicate/resources/models/models.py | 6 +----- src/replicate/resources/models/versions.py | 5 +---- src/replicate/resources/predictions.py | 6 +----- 6 files changed, 6 insertions(+), 27 deletions(-) diff --git a/src/replicate/_client.py b/src/replicate/_client.py index b59e805..6afa0ec 100644 --- a/src/replicate/_client.py +++ b/src/replicate/_client.py @@ -19,10 +19,7 @@ ProxiesTypes, RequestOptions, ) -from ._utils import ( - is_given, - get_async_library, -) +from ._utils import is_given, get_async_library from ._version import __version__ from .resources import accounts, hardware, trainings, collections, predictions from ._streaming import Stream as Stream, AsyncStream as AsyncStream diff --git a/src/replicate/resources/deployments/deployments.py b/src/replicate/resources/deployments/deployments.py index 7e2fac7..565a1c6 100644 --- a/src/replicate/resources/deployments/deployments.py +++ b/src/replicate/resources/deployments/deployments.py @@ -6,10 +6,7 @@ from ...types import deployment_create_params, deployment_update_params from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven -from ..._utils import ( - maybe_transform, - async_maybe_transform, -) +from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( diff --git a/src/replicate/resources/deployments/predictions.py b/src/replicate/resources/deployments/predictions.py index 7177332..c94fff0 100644 --- a/src/replicate/resources/deployments/predictions.py +++ b/src/replicate/resources/deployments/predictions.py @@ -8,11 +8,7 @@ import httpx from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ..._utils import ( - maybe_transform, - strip_not_given, - async_maybe_transform, -) +from ..._utils import maybe_transform, strip_not_given, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( diff --git a/src/replicate/resources/models/models.py b/src/replicate/resources/models/models.py index 6d8c6eb..435101f 100644 --- a/src/replicate/resources/models/models.py +++ b/src/replicate/resources/models/models.py @@ -9,11 +9,7 @@ from ...types import model_create_params, model_create_prediction_params from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven -from ..._utils import ( - maybe_transform, - strip_not_given, - async_maybe_transform, -) +from ..._utils import maybe_transform, strip_not_given, async_maybe_transform from .versions import ( VersionsResource, AsyncVersionsResource, diff --git a/src/replicate/resources/models/versions.py b/src/replicate/resources/models/versions.py index 2065fe0..308d04b 100644 --- a/src/replicate/resources/models/versions.py +++ b/src/replicate/resources/models/versions.py @@ -8,10 +8,7 @@ import httpx from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven -from ..._utils import ( - maybe_transform, - async_maybe_transform, -) +from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( diff --git a/src/replicate/resources/predictions.py b/src/replicate/resources/predictions.py index c03991d..abe0ca1 100644 --- a/src/replicate/resources/predictions.py +++ b/src/replicate/resources/predictions.py @@ -10,11 +10,7 @@ from ..types import prediction_list_params, prediction_create_params from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven -from .._utils import ( - maybe_transform, - strip_not_given, - async_maybe_transform, -) +from .._utils import maybe_transform, strip_not_given, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource from .._response import ( From 2918ebad39df868485fed02a2d0020bef72d24b9 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 23 Apr 2025 04:35:56 +0000 Subject: [PATCH 09/12] chore(internal): fix list file params --- src/replicate/_utils/_utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/replicate/_utils/_utils.py b/src/replicate/_utils/_utils.py index e5811bb..ea3cf3f 100644 --- a/src/replicate/_utils/_utils.py +++ b/src/replicate/_utils/_utils.py @@ -72,8 +72,16 @@ def _extract_items( from .._files import assert_is_file_content # We have exhausted the path, return the entry we found. - assert_is_file_content(obj, key=flattened_key) assert flattened_key is not None + + if is_list(obj): + files: list[tuple[str, FileTypes]] = [] + for entry in obj: + assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "") + files.append((flattened_key + "[]", cast(FileTypes, entry))) + return files + + assert_is_file_content(obj, key=flattened_key) return [(flattened_key, cast(FileTypes, obj))] index += 1 From 75005e11045385d0596911bbbbb062207450bd14 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 23 Apr 2025 04:36:31 +0000 Subject: [PATCH 10/12] chore(internal): refactor retries to not use recursion --- src/replicate/_base_client.py | 414 ++++++++++++++-------------------- 1 file changed, 175 insertions(+), 239 deletions(-) diff --git a/src/replicate/_base_client.py b/src/replicate/_base_client.py index fb06997..84db2c9 100644 --- a/src/replicate/_base_client.py +++ b/src/replicate/_base_client.py @@ -437,8 +437,7 @@ def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0 headers = httpx.Headers(headers_dict) idempotency_header = self._idempotency_header - if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers: - options.idempotency_key = options.idempotency_key or self._idempotency_key() + if idempotency_header and options.idempotency_key and idempotency_header not in headers: headers[idempotency_header] = options.idempotency_key # Don't set these headers if they were already set or removed by the caller. We check @@ -903,7 +902,6 @@ def request( self, cast_to: Type[ResponseT], options: FinalRequestOptions, - remaining_retries: Optional[int] = None, *, stream: Literal[True], stream_cls: Type[_StreamT], @@ -914,7 +912,6 @@ def request( self, cast_to: Type[ResponseT], options: FinalRequestOptions, - remaining_retries: Optional[int] = None, *, stream: Literal[False] = False, ) -> ResponseT: ... @@ -924,7 +921,6 @@ def request( self, cast_to: Type[ResponseT], options: FinalRequestOptions, - remaining_retries: Optional[int] = None, *, stream: bool = False, stream_cls: Type[_StreamT] | None = None, @@ -934,125 +930,109 @@ def request( self, cast_to: Type[ResponseT], options: FinalRequestOptions, - remaining_retries: Optional[int] = None, *, stream: bool = False, stream_cls: type[_StreamT] | None = None, ) -> ResponseT | _StreamT: - if remaining_retries is not None: - retries_taken = options.get_max_retries(self.max_retries) - remaining_retries - else: - retries_taken = 0 - - return self._request( - cast_to=cast_to, - options=options, - stream=stream, - stream_cls=stream_cls, - retries_taken=retries_taken, - ) + cast_to = self._maybe_override_cast_to(cast_to, options) - def _request( - self, - *, - cast_to: Type[ResponseT], - options: FinalRequestOptions, - retries_taken: int, - stream: bool, - stream_cls: type[_StreamT] | None, - ) -> ResponseT | _StreamT: # create a copy of the options we were given so that if the # options are mutated later & we then retry, the retries are # given the original options input_options = model_copy(options) - - cast_to = self._maybe_override_cast_to(cast_to, options) - options = self._prepare_options(options) - - remaining_retries = options.get_max_retries(self.max_retries) - retries_taken - request = self._build_request(options, retries_taken=retries_taken) - self._prepare_request(request) - - if options.idempotency_key: + if input_options.idempotency_key is None and input_options.method.lower() != "get": # ensure the idempotency key is reused between requests - input_options.idempotency_key = options.idempotency_key + input_options.idempotency_key = self._idempotency_key() - kwargs: HttpxSendArgs = {} - if self.custom_auth is not None: - kwargs["auth"] = self.custom_auth + response: httpx.Response | None = None + max_retries = input_options.get_max_retries(self.max_retries) - log.debug("Sending HTTP Request: %s %s", request.method, request.url) + retries_taken = 0 + for retries_taken in range(max_retries + 1): + options = model_copy(input_options) + options = self._prepare_options(options) - try: - response = self._client.send( - request, - stream=stream or self._should_stream_response_body(request=request), - **kwargs, - ) - except httpx.TimeoutException as err: - log.debug("Encountered httpx.TimeoutException", exc_info=True) + remaining_retries = max_retries - retries_taken + request = self._build_request(options, retries_taken=retries_taken) + self._prepare_request(request) - if remaining_retries > 0: - return self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - stream=stream, - stream_cls=stream_cls, - response_headers=None, - ) + kwargs: HttpxSendArgs = {} + if self.custom_auth is not None: + kwargs["auth"] = self.custom_auth - log.debug("Raising timeout error") - raise APITimeoutError(request=request) from err - except Exception as err: - log.debug("Encountered Exception", exc_info=True) + log.debug("Sending HTTP Request: %s %s", request.method, request.url) - if remaining_retries > 0: - return self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - stream=stream, - stream_cls=stream_cls, - response_headers=None, + response = None + try: + response = self._client.send( + request, + stream=stream or self._should_stream_response_body(request=request), + **kwargs, ) + except httpx.TimeoutException as err: + log.debug("Encountered httpx.TimeoutException", exc_info=True) + + if remaining_retries > 0: + self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=None, + ) + continue + + log.debug("Raising timeout error") + raise APITimeoutError(request=request) from err + except Exception as err: + log.debug("Encountered Exception", exc_info=True) + + if remaining_retries > 0: + self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=None, + ) + continue + + log.debug("Raising connection error") + raise APIConnectionError(request=request) from err + + log.debug( + 'HTTP Response: %s %s "%i %s" %s', + request.method, + request.url, + response.status_code, + response.reason_phrase, + response.headers, + ) - log.debug("Raising connection error") - raise APIConnectionError(request=request) from err - - log.debug( - 'HTTP Response: %s %s "%i %s" %s', - request.method, - request.url, - response.status_code, - response.reason_phrase, - response.headers, - ) + try: + response.raise_for_status() + except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code + log.debug("Encountered httpx.HTTPStatusError", exc_info=True) + + if remaining_retries > 0 and self._should_retry(err.response): + err.response.close() + self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=response, + ) + continue - try: - response.raise_for_status() - except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code - log.debug("Encountered httpx.HTTPStatusError", exc_info=True) - - if remaining_retries > 0 and self._should_retry(err.response): - err.response.close() - return self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - response_headers=err.response.headers, - stream=stream, - stream_cls=stream_cls, - ) + # If the response is streamed then we need to explicitly read the response + # to completion before attempting to access the response text. + if not err.response.is_closed: + err.response.read() - # If the response is streamed then we need to explicitly read the response - # to completion before attempting to access the response text. - if not err.response.is_closed: - err.response.read() + log.debug("Re-raising status error") + raise self._make_status_error_from_response(err.response) from None - log.debug("Re-raising status error") - raise self._make_status_error_from_response(err.response) from None + break + assert response is not None, "could not resolve response (should never happen)" return self._process_response( cast_to=cast_to, options=options, @@ -1062,37 +1042,20 @@ def _request( retries_taken=retries_taken, ) - def _retry_request( - self, - options: FinalRequestOptions, - cast_to: Type[ResponseT], - *, - retries_taken: int, - response_headers: httpx.Headers | None, - stream: bool, - stream_cls: type[_StreamT] | None, - ) -> ResponseT | _StreamT: - remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + def _sleep_for_retry( + self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None + ) -> None: + remaining_retries = max_retries - retries_taken if remaining_retries == 1: log.debug("1 retry left") else: log.debug("%i retries left", remaining_retries) - timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers) + timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None) log.info("Retrying request to %s in %f seconds", options.url, timeout) - # In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a - # different thread if necessary. time.sleep(timeout) - return self._request( - options=options, - cast_to=cast_to, - retries_taken=retries_taken + 1, - stream=stream, - stream_cls=stream_cls, - ) - def _process_response( self, *, @@ -1436,7 +1399,6 @@ async def request( options: FinalRequestOptions, *, stream: Literal[False] = False, - remaining_retries: Optional[int] = None, ) -> ResponseT: ... @overload @@ -1447,7 +1409,6 @@ async def request( *, stream: Literal[True], stream_cls: type[_AsyncStreamT], - remaining_retries: Optional[int] = None, ) -> _AsyncStreamT: ... @overload @@ -1458,7 +1419,6 @@ async def request( *, stream: bool, stream_cls: type[_AsyncStreamT] | None = None, - remaining_retries: Optional[int] = None, ) -> ResponseT | _AsyncStreamT: ... async def request( @@ -1468,120 +1428,111 @@ async def request( *, stream: bool = False, stream_cls: type[_AsyncStreamT] | None = None, - remaining_retries: Optional[int] = None, - ) -> ResponseT | _AsyncStreamT: - if remaining_retries is not None: - retries_taken = options.get_max_retries(self.max_retries) - remaining_retries - else: - retries_taken = 0 - - return await self._request( - cast_to=cast_to, - options=options, - stream=stream, - stream_cls=stream_cls, - retries_taken=retries_taken, - ) - - async def _request( - self, - cast_to: Type[ResponseT], - options: FinalRequestOptions, - *, - stream: bool, - stream_cls: type[_AsyncStreamT] | None, - retries_taken: int, ) -> ResponseT | _AsyncStreamT: if self._platform is None: # `get_platform` can make blocking IO calls so we # execute it earlier while we are in an async context self._platform = await asyncify(get_platform)() + cast_to = self._maybe_override_cast_to(cast_to, options) + # create a copy of the options we were given so that if the # options are mutated later & we then retry, the retries are # given the original options input_options = model_copy(options) - - cast_to = self._maybe_override_cast_to(cast_to, options) - options = await self._prepare_options(options) - - remaining_retries = options.get_max_retries(self.max_retries) - retries_taken - request = self._build_request(options, retries_taken=retries_taken) - await self._prepare_request(request) - - if options.idempotency_key: + if input_options.idempotency_key is None and input_options.method.lower() != "get": # ensure the idempotency key is reused between requests - input_options.idempotency_key = options.idempotency_key + input_options.idempotency_key = self._idempotency_key() - kwargs: HttpxSendArgs = {} - if self.custom_auth is not None: - kwargs["auth"] = self.custom_auth + response: httpx.Response | None = None + max_retries = input_options.get_max_retries(self.max_retries) - try: - response = await self._client.send( - request, - stream=stream or self._should_stream_response_body(request=request), - **kwargs, - ) - except httpx.TimeoutException as err: - log.debug("Encountered httpx.TimeoutException", exc_info=True) + retries_taken = 0 + for retries_taken in range(max_retries + 1): + options = model_copy(input_options) + options = await self._prepare_options(options) - if remaining_retries > 0: - return await self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - stream=stream, - stream_cls=stream_cls, - response_headers=None, - ) + remaining_retries = max_retries - retries_taken + request = self._build_request(options, retries_taken=retries_taken) + await self._prepare_request(request) - log.debug("Raising timeout error") - raise APITimeoutError(request=request) from err - except Exception as err: - log.debug("Encountered Exception", exc_info=True) + kwargs: HttpxSendArgs = {} + if self.custom_auth is not None: + kwargs["auth"] = self.custom_auth - if remaining_retries > 0: - return await self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - stream=stream, - stream_cls=stream_cls, - response_headers=None, - ) + log.debug("Sending HTTP Request: %s %s", request.method, request.url) - log.debug("Raising connection error") - raise APIConnectionError(request=request) from err + response = None + try: + response = await self._client.send( + request, + stream=stream or self._should_stream_response_body(request=request), + **kwargs, + ) + except httpx.TimeoutException as err: + log.debug("Encountered httpx.TimeoutException", exc_info=True) + + if remaining_retries > 0: + await self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=None, + ) + continue + + log.debug("Raising timeout error") + raise APITimeoutError(request=request) from err + except Exception as err: + log.debug("Encountered Exception", exc_info=True) + + if remaining_retries > 0: + await self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=None, + ) + continue + + log.debug("Raising connection error") + raise APIConnectionError(request=request) from err + + log.debug( + 'HTTP Response: %s %s "%i %s" %s', + request.method, + request.url, + response.status_code, + response.reason_phrase, + response.headers, + ) - log.debug( - 'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase - ) + try: + response.raise_for_status() + except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code + log.debug("Encountered httpx.HTTPStatusError", exc_info=True) + + if remaining_retries > 0 and self._should_retry(err.response): + await err.response.aclose() + await self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=response, + ) + continue - try: - response.raise_for_status() - except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code - log.debug("Encountered httpx.HTTPStatusError", exc_info=True) - - if remaining_retries > 0 and self._should_retry(err.response): - await err.response.aclose() - return await self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - response_headers=err.response.headers, - stream=stream, - stream_cls=stream_cls, - ) + # If the response is streamed then we need to explicitly read the response + # to completion before attempting to access the response text. + if not err.response.is_closed: + await err.response.aread() - # If the response is streamed then we need to explicitly read the response - # to completion before attempting to access the response text. - if not err.response.is_closed: - await err.response.aread() + log.debug("Re-raising status error") + raise self._make_status_error_from_response(err.response) from None - log.debug("Re-raising status error") - raise self._make_status_error_from_response(err.response) from None + break + assert response is not None, "could not resolve response (should never happen)" return await self._process_response( cast_to=cast_to, options=options, @@ -1591,35 +1542,20 @@ async def _request( retries_taken=retries_taken, ) - async def _retry_request( - self, - options: FinalRequestOptions, - cast_to: Type[ResponseT], - *, - retries_taken: int, - response_headers: httpx.Headers | None, - stream: bool, - stream_cls: type[_AsyncStreamT] | None, - ) -> ResponseT | _AsyncStreamT: - remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + async def _sleep_for_retry( + self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None + ) -> None: + remaining_retries = max_retries - retries_taken if remaining_retries == 1: log.debug("1 retry left") else: log.debug("%i retries left", remaining_retries) - timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers) + timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None) log.info("Retrying request to %s in %f seconds", options.url, timeout) await anyio.sleep(timeout) - return await self._request( - options=options, - cast_to=cast_to, - retries_taken=retries_taken + 1, - stream=stream, - stream_cls=stream_cls, - ) - async def _process_response( self, *, From c907599a6736e781f3f80062eb4d03ed92f03403 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 23 Apr 2025 04:36:56 +0000 Subject: [PATCH 11/12] fix(pydantic v1): more robust ModelField.annotation check --- src/replicate/_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/replicate/_models.py b/src/replicate/_models.py index 58b9263..798956f 100644 --- a/src/replicate/_models.py +++ b/src/replicate/_models.py @@ -626,8 +626,8 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, # Note: if one variant defines an alias then they all should discriminator_alias = field_info.alias - if field_info.annotation and is_literal_type(field_info.annotation): - for entry in get_args(field_info.annotation): + if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation): + for entry in get_args(annotation): if isinstance(entry, str): mapping[entry] = variant From 8d82164749ccdcb0e1fd9700659281a3db5abfb2 Mon Sep 17 00:00:00 2001 From: "stainless-app[bot]" <142633134+stainless-app[bot]@users.noreply.github.com> Date: Wed, 23 Apr 2025 04:37:12 +0000 Subject: [PATCH 12/12] release: 0.1.0-alpha.3 --- .release-please-manifest.json | 2 +- CHANGELOG.md | 29 +++++++++++++++++++++++++++++ pyproject.toml | 2 +- src/replicate/_version.py | 2 +- 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/.release-please-manifest.json b/.release-please-manifest.json index f14b480..aaf968a 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.1.0-alpha.2" + ".": "0.1.0-alpha.3" } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index d51ac7f..6bde67f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,34 @@ # Changelog +## 0.1.0-alpha.3 (2025-04-23) + +Full Changelog: [v0.1.0-alpha.2...v0.1.0-alpha.3](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0-alpha.2...v0.1.0-alpha.3) + +### ⚠ BREAKING CHANGES + +* **api:** use correct env var for bearer token + +### Features + +* **api:** api update ([7ebd598](https://github.com/replicate/replicate-python-stainless/commit/7ebd59873181c74dbaa035ac599abcbbefb3ee62)) + + +### Bug Fixes + +* **api:** use correct env var for bearer token ([00eab77](https://github.com/replicate/replicate-python-stainless/commit/00eab7702f8f2699ce9b3070f23202278ac21855)) +* **pydantic v1:** more robust ModelField.annotation check ([c907599](https://github.com/replicate/replicate-python-stainless/commit/c907599a6736e781f3f80062eb4d03ed92f03403)) + + +### Chores + +* **ci:** add timeout thresholds for CI jobs ([1bad4d3](https://github.com/replicate/replicate-python-stainless/commit/1bad4d3d3676a323032f37f0195ff640fcce3458)) +* **internal:** base client updates ([c1d6ed5](https://github.com/replicate/replicate-python-stainless/commit/c1d6ed59ed0f06012922ec6d0bae376852523d81)) +* **internal:** bump pyright version ([f1e4d14](https://github.com/replicate/replicate-python-stainless/commit/f1e4d140104ff317b94cb2dd88ec850a9b8bce54)) +* **internal:** fix list file params ([2918eba](https://github.com/replicate/replicate-python-stainless/commit/2918ebad39df868485fed02a2d0020bef72d24b9)) +* **internal:** import reformatting ([4cdf515](https://github.com/replicate/replicate-python-stainless/commit/4cdf515372a9e936c3a18afd24a444a778b1f7f5)) +* **internal:** refactor retries to not use recursion ([75005e1](https://github.com/replicate/replicate-python-stainless/commit/75005e11045385d0596911bbbbb062207450bd14)) +* **internal:** update models test ([fc34c6d](https://github.com/replicate/replicate-python-stainless/commit/fc34c6d4fc36a41441ab8417f85343e640b53b76)) + ## 0.1.0-alpha.2 (2025-04-16) Full Changelog: [v0.1.0-alpha.1...v0.1.0-alpha.2](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0-alpha.1...v0.1.0-alpha.2) diff --git a/pyproject.toml b/pyproject.toml index 8eb0e54..4376f13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "replicate-stainless" -version = "0.1.0-alpha.2" +version = "0.1.0-alpha.3" description = "The official Python library for the replicate-client API" dynamic = ["readme"] license = "Apache-2.0" diff --git a/src/replicate/_version.py b/src/replicate/_version.py index dfeb99c..5b5e6c4 100644 --- a/src/replicate/_version.py +++ b/src/replicate/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "replicate" -__version__ = "0.1.0-alpha.2" # x-release-please-version +__version__ = "0.1.0-alpha.3" # x-release-please-version