From 00eab7702f8f2699ce9b3070f23202278ac21855 Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Wed, 16 Apr 2025 19:14:13 +0000
Subject: [PATCH 01/12] fix(api)!: use correct env var for bearer token
The correct env var to use is REPLICATE_API_TOKEN, not REPLICATE_CLIENT_BEARER_TOKEN
---
.stats.yml | 2 +-
README.md | 10 +++-------
src/replicate/_client.py | 12 ++++++------
tests/test_client.py | 4 ++--
4 files changed, 12 insertions(+), 16 deletions(-)
diff --git a/.stats.yml b/.stats.yml
index 91aadf3..263f930 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,4 +1,4 @@
configured_endpoints: 27
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-37bb31ed76da599d3bded543a3765f745c8575d105c13554df7f8361c3641482.yml
openapi_spec_hash: 15bdec12ca84042768bfb28cc48dfce3
-config_hash: 810de4c2eee1a7649263cff01f00da7c
+config_hash: 64fd304bf0ff077a74053ce79a255106
diff --git a/README.md b/README.md
index 0015b4c..23b57ae 100644
--- a/README.md
+++ b/README.md
@@ -28,9 +28,7 @@ import os
from replicate import ReplicateClient
client = ReplicateClient(
- bearer_token=os.environ.get(
- "REPLICATE_CLIENT_BEARER_TOKEN"
- ), # This is the default and can be omitted
+ bearer_token=os.environ.get("REPLICATE_API_TOKEN"), # This is the default and can be omitted
)
accounts = client.accounts.list()
@@ -39,7 +37,7 @@ print(accounts.type)
While you can provide a `bearer_token` keyword argument,
we recommend using [python-dotenv](https://pypi.org/project/python-dotenv/)
-to add `REPLICATE_CLIENT_BEARER_TOKEN="My Bearer Token"` to your `.env` file
+to add `REPLICATE_API_TOKEN="My Bearer Token"` to your `.env` file
so that your Bearer Token is not stored in source control.
## Async usage
@@ -52,9 +50,7 @@ import asyncio
from replicate import AsyncReplicateClient
client = AsyncReplicateClient(
- bearer_token=os.environ.get(
- "REPLICATE_CLIENT_BEARER_TOKEN"
- ), # This is the default and can be omitted
+ bearer_token=os.environ.get("REPLICATE_API_TOKEN"), # This is the default and can be omitted
)
diff --git a/src/replicate/_client.py b/src/replicate/_client.py
index 06e86ff..b59e805 100644
--- a/src/replicate/_client.py
+++ b/src/replicate/_client.py
@@ -88,13 +88,13 @@ def __init__(
) -> None:
"""Construct a new synchronous ReplicateClient client instance.
- This automatically infers the `bearer_token` argument from the `REPLICATE_CLIENT_BEARER_TOKEN` environment variable if it is not provided.
+ This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided.
"""
if bearer_token is None:
- bearer_token = os.environ.get("REPLICATE_CLIENT_BEARER_TOKEN")
+ bearer_token = os.environ.get("REPLICATE_API_TOKEN")
if bearer_token is None:
raise ReplicateClientError(
- "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_CLIENT_BEARER_TOKEN environment variable"
+ "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable"
)
self.bearer_token = bearer_token
@@ -270,13 +270,13 @@ def __init__(
) -> None:
"""Construct a new async AsyncReplicateClient client instance.
- This automatically infers the `bearer_token` argument from the `REPLICATE_CLIENT_BEARER_TOKEN` environment variable if it is not provided.
+ This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided.
"""
if bearer_token is None:
- bearer_token = os.environ.get("REPLICATE_CLIENT_BEARER_TOKEN")
+ bearer_token = os.environ.get("REPLICATE_API_TOKEN")
if bearer_token is None:
raise ReplicateClientError(
- "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_CLIENT_BEARER_TOKEN environment variable"
+ "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable"
)
self.bearer_token = bearer_token
diff --git a/tests/test_client.py b/tests/test_client.py
index bd5b999..279675f 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -346,7 +346,7 @@ def test_validate_headers(self) -> None:
assert request.headers.get("Authorization") == f"Bearer {bearer_token}"
with pytest.raises(ReplicateClientError):
- with update_env(**{"REPLICATE_CLIENT_BEARER_TOKEN": Omit()}):
+ with update_env(**{"REPLICATE_API_TOKEN": Omit()}):
client2 = ReplicateClient(base_url=base_url, bearer_token=None, _strict_response_validation=True)
_ = client2
@@ -1126,7 +1126,7 @@ def test_validate_headers(self) -> None:
assert request.headers.get("Authorization") == f"Bearer {bearer_token}"
with pytest.raises(ReplicateClientError):
- with update_env(**{"REPLICATE_CLIENT_BEARER_TOKEN": Omit()}):
+ with update_env(**{"REPLICATE_API_TOKEN": Omit()}):
client2 = AsyncReplicateClient(base_url=base_url, bearer_token=None, _strict_response_validation=True)
_ = client2
From 7ebd59873181c74dbaa035ac599abcbbefb3ee62 Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Wed, 16 Apr 2025 23:30:55 +0000
Subject: [PATCH 02/12] feat(api): api update
---
.stats.yml | 4 +-
api.md | 20 ++++-
src/replicate/resources/models/versions.py | 11 ++-
src/replicate/resources/trainings.py | 46 ++++++------
src/replicate/types/__init__.py | 3 +
.../types/deployment_list_response.py | 6 +-
src/replicate/types/models/__init__.py | 1 +
.../version_create_training_response.py | 74 +++++++++++++++++++
src/replicate/types/prediction_output.py | 2 +-
.../types/training_cancel_response.py | 74 +++++++++++++++++++
src/replicate/types/training_list_response.py | 74 +++++++++++++++++++
.../types/training_retrieve_response.py | 74 +++++++++++++++++++
tests/api_resources/models/test_versions.py | 18 +++--
tests/api_resources/test_trainings.py | 39 +++++-----
14 files changed, 379 insertions(+), 67 deletions(-)
create mode 100644 src/replicate/types/models/version_create_training_response.py
create mode 100644 src/replicate/types/training_cancel_response.py
create mode 100644 src/replicate/types/training_list_response.py
create mode 100644 src/replicate/types/training_retrieve_response.py
diff --git a/.stats.yml b/.stats.yml
index 263f930..1448ef2 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,4 +1,4 @@
configured_endpoints: 27
-openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-37bb31ed76da599d3bded543a3765f745c8575d105c13554df7f8361c3641482.yml
-openapi_spec_hash: 15bdec12ca84042768bfb28cc48dfce3
+openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-2788217b7ad7d61d1a77800bc5ff12a6810f1692d4d770b72fa8f898c6a055ab.yml
+openapi_spec_hash: 4423bf747e228484547b441468a9f156
config_hash: 64fd304bf0ff077a74053ce79a255106
diff --git a/api.md b/api.md
index 60e91e6..509868d 100644
--- a/api.md
+++ b/api.md
@@ -75,12 +75,18 @@ Methods:
## Versions
+Types:
+
+```python
+from replicate.types.models import VersionCreateTrainingResponse
+```
+
Methods:
- client.models.versions.retrieve(version_id, \*, model_owner, model_name) -> None
- client.models.versions.list(model_name, \*, model_owner) -> None
- client.models.versions.delete(version_id, \*, model_owner, model_name) -> None
-- client.models.versions.create_training(version_id, \*, model_owner, model_name, \*\*params) -> None
+- client.models.versions.create_training(version_id, \*, model_owner, model_name, \*\*params) -> VersionCreateTrainingResponse
# Predictions
@@ -99,11 +105,17 @@ Methods:
# Trainings
+Types:
+
+```python
+from replicate.types import TrainingRetrieveResponse, TrainingListResponse, TrainingCancelResponse
+```
+
Methods:
-- client.trainings.retrieve(training_id) -> None
-- client.trainings.list() -> None
-- client.trainings.cancel(training_id) -> None
+- client.trainings.retrieve(training_id) -> TrainingRetrieveResponse
+- client.trainings.list() -> SyncCursorURLPage[TrainingListResponse]
+- client.trainings.cancel(training_id) -> TrainingCancelResponse
# Webhooks
diff --git a/src/replicate/resources/models/versions.py b/src/replicate/resources/models/versions.py
index eb411ef..203bb5d 100644
--- a/src/replicate/resources/models/versions.py
+++ b/src/replicate/resources/models/versions.py
@@ -22,6 +22,7 @@
)
from ..._base_client import make_request_options
from ...types.models import version_create_training_params
+from ...types.models.version_create_training_response import VersionCreateTrainingResponse
__all__ = ["VersionsResource", "AsyncVersionsResource"]
@@ -281,7 +282,7 @@ def create_training(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
+ ) -> VersionCreateTrainingResponse:
"""
Start a new training of the model version you specify.
@@ -398,7 +399,6 @@ def create_training(
raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
if not version_id:
raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
return self._post(
f"/models/{model_owner}/{model_name}/versions/{version_id}/trainings",
body=maybe_transform(
@@ -413,7 +413,7 @@ def create_training(
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=NoneType,
+ cast_to=VersionCreateTrainingResponse,
)
@@ -672,7 +672,7 @@ async def create_training(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
+ ) -> VersionCreateTrainingResponse:
"""
Start a new training of the model version you specify.
@@ -789,7 +789,6 @@ async def create_training(
raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
if not version_id:
raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
return await self._post(
f"/models/{model_owner}/{model_name}/versions/{version_id}/trainings",
body=await async_maybe_transform(
@@ -804,7 +803,7 @@ async def create_training(
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=NoneType,
+ cast_to=VersionCreateTrainingResponse,
)
diff --git a/src/replicate/resources/trainings.py b/src/replicate/resources/trainings.py
index df64c19..46c5872 100644
--- a/src/replicate/resources/trainings.py
+++ b/src/replicate/resources/trainings.py
@@ -4,7 +4,7 @@
import httpx
-from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
+from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
from .._response import (
@@ -13,7 +13,11 @@
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
-from .._base_client import make_request_options
+from ..pagination import SyncCursorURLPage, AsyncCursorURLPage
+from .._base_client import AsyncPaginator, make_request_options
+from ..types.training_list_response import TrainingListResponse
+from ..types.training_cancel_response import TrainingCancelResponse
+from ..types.training_retrieve_response import TrainingRetrieveResponse
__all__ = ["TrainingsResource", "AsyncTrainingsResource"]
@@ -48,7 +52,7 @@ def retrieve(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
+ ) -> TrainingRetrieveResponse:
"""
Get the current state of a training.
@@ -123,13 +127,12 @@ def retrieve(
"""
if not training_id:
raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
return self._get(
f"/trainings/{training_id}",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=NoneType,
+ cast_to=TrainingRetrieveResponse,
)
def list(
@@ -141,7 +144,7 @@ def list(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
+ ) -> SyncCursorURLPage[TrainingListResponse]:
"""
Get a paginated list of all trainings created by the user or organization
associated with the provided API token.
@@ -207,13 +210,13 @@ def list(
`version` will be the unique ID of model version used to create the training.
"""
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return self._get(
+ return self._get_api_list(
"/trainings",
+ page=SyncCursorURLPage[TrainingListResponse],
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=NoneType,
+ model=TrainingListResponse,
)
def cancel(
@@ -226,7 +229,7 @@ def cancel(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
+ ) -> TrainingCancelResponse:
"""
Cancel a training
@@ -241,13 +244,12 @@ def cancel(
"""
if not training_id:
raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
return self._post(
f"/trainings/{training_id}/cancel",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=NoneType,
+ cast_to=TrainingCancelResponse,
)
@@ -281,7 +283,7 @@ async def retrieve(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
+ ) -> TrainingRetrieveResponse:
"""
Get the current state of a training.
@@ -356,16 +358,15 @@ async def retrieve(
"""
if not training_id:
raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
return await self._get(
f"/trainings/{training_id}",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=NoneType,
+ cast_to=TrainingRetrieveResponse,
)
- async def list(
+ def list(
self,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
@@ -374,7 +375,7 @@ async def list(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
+ ) -> AsyncPaginator[TrainingListResponse, AsyncCursorURLPage[TrainingListResponse]]:
"""
Get a paginated list of all trainings created by the user or organization
associated with the provided API token.
@@ -440,13 +441,13 @@ async def list(
`version` will be the unique ID of model version used to create the training.
"""
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return await self._get(
+ return self._get_api_list(
"/trainings",
+ page=AsyncCursorURLPage[TrainingListResponse],
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=NoneType,
+ model=TrainingListResponse,
)
async def cancel(
@@ -459,7 +460,7 @@ async def cancel(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
+ ) -> TrainingCancelResponse:
"""
Cancel a training
@@ -474,13 +475,12 @@ async def cancel(
"""
if not training_id:
raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
return await self._post(
f"/trainings/{training_id}/cancel",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=NoneType,
+ cast_to=TrainingCancelResponse,
)
diff --git a/src/replicate/types/__init__.py b/src/replicate/types/__init__.py
index e2b3c58..7d2d8f4 100644
--- a/src/replicate/types/__init__.py
+++ b/src/replicate/types/__init__.py
@@ -9,11 +9,14 @@
from .account_list_response import AccountListResponse as AccountListResponse
from .hardware_list_response import HardwareListResponse as HardwareListResponse
from .prediction_list_params import PredictionListParams as PredictionListParams
+from .training_list_response import TrainingListResponse as TrainingListResponse
from .deployment_create_params import DeploymentCreateParams as DeploymentCreateParams
from .deployment_list_response import DeploymentListResponse as DeploymentListResponse
from .deployment_update_params import DeploymentUpdateParams as DeploymentUpdateParams
from .prediction_create_params import PredictionCreateParams as PredictionCreateParams
+from .training_cancel_response import TrainingCancelResponse as TrainingCancelResponse
from .deployment_create_response import DeploymentCreateResponse as DeploymentCreateResponse
from .deployment_update_response import DeploymentUpdateResponse as DeploymentUpdateResponse
+from .training_retrieve_response import TrainingRetrieveResponse as TrainingRetrieveResponse
from .deployment_retrieve_response import DeploymentRetrieveResponse as DeploymentRetrieveResponse
from .model_create_prediction_params import ModelCreatePredictionParams as ModelCreatePredictionParams
diff --git a/src/replicate/types/deployment_list_response.py b/src/replicate/types/deployment_list_response.py
index 9f64487..2606d38 100644
--- a/src/replicate/types/deployment_list_response.py
+++ b/src/replicate/types/deployment_list_response.py
@@ -49,11 +49,7 @@ class CurrentRelease(BaseModel):
"""The model identifier string in the format of `{model_owner}/{model_name}`."""
number: Optional[int] = None
- """The release number.
-
- This is an auto-incrementing integer that starts at 1, and is set automatically
- when a deployment is created.
- """
+ """The release number."""
version: Optional[str] = None
"""The ID of the model version used in the release."""
diff --git a/src/replicate/types/models/__init__.py b/src/replicate/types/models/__init__.py
index 98e1a3a..2d89bc6 100644
--- a/src/replicate/types/models/__init__.py
+++ b/src/replicate/types/models/__init__.py
@@ -3,3 +3,4 @@
from __future__ import annotations
from .version_create_training_params import VersionCreateTrainingParams as VersionCreateTrainingParams
+from .version_create_training_response import VersionCreateTrainingResponse as VersionCreateTrainingResponse
diff --git a/src/replicate/types/models/version_create_training_response.py b/src/replicate/types/models/version_create_training_response.py
new file mode 100644
index 0000000..5b005fe
--- /dev/null
+++ b/src/replicate/types/models/version_create_training_response.py
@@ -0,0 +1,74 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Dict, Optional
+from datetime import datetime
+from typing_extensions import Literal
+
+from ..._models import BaseModel
+
+__all__ = ["VersionCreateTrainingResponse", "Metrics", "Output", "URLs"]
+
+
+class Metrics(BaseModel):
+ predict_time: Optional[float] = None
+ """The amount of CPU or GPU time, in seconds, that the training used while running"""
+
+
+class Output(BaseModel):
+ version: Optional[str] = None
+ """The version of the model created by the training"""
+
+ weights: Optional[str] = None
+ """The weights of the trained model"""
+
+
+class URLs(BaseModel):
+ cancel: Optional[str] = None
+ """URL to cancel the training"""
+
+ get: Optional[str] = None
+ """URL to get the training details"""
+
+
+class VersionCreateTrainingResponse(BaseModel):
+ id: Optional[str] = None
+ """The unique ID of the training"""
+
+ completed_at: Optional[datetime] = None
+ """The time when the training completed"""
+
+ created_at: Optional[datetime] = None
+ """The time when the training was created"""
+
+ error: Optional[str] = None
+ """Error message if the training failed"""
+
+ input: Optional[Dict[str, object]] = None
+ """The input parameters used for the training"""
+
+ logs: Optional[str] = None
+ """The logs from the training process"""
+
+ metrics: Optional[Metrics] = None
+ """Metrics about the training process"""
+
+ model: Optional[str] = None
+ """The name of the model in the format owner/name"""
+
+ output: Optional[Output] = None
+ """The output of the training process"""
+
+ source: Optional[Literal["web", "api"]] = None
+ """How the training was created"""
+
+ started_at: Optional[datetime] = None
+ """The time when the training started"""
+
+ status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None
+ """The current status of the training"""
+
+ urls: Optional[URLs] = None
+ """URLs for interacting with the training"""
+
+ version: Optional[str] = None
+ """The ID of the model version used for training"""
diff --git a/src/replicate/types/prediction_output.py b/src/replicate/types/prediction_output.py
index 2c46100..8b320b5 100644
--- a/src/replicate/types/prediction_output.py
+++ b/src/replicate/types/prediction_output.py
@@ -6,5 +6,5 @@
__all__ = ["PredictionOutput"]
PredictionOutput: TypeAlias = Union[
- Optional[Dict[str, object]], Optional[List[object]], Optional[str], Optional[float], Optional[bool]
+ Optional[Dict[str, object]], Optional[List[Dict[str, object]]], Optional[str], Optional[float], Optional[bool]
]
diff --git a/src/replicate/types/training_cancel_response.py b/src/replicate/types/training_cancel_response.py
new file mode 100644
index 0000000..9af512b
--- /dev/null
+++ b/src/replicate/types/training_cancel_response.py
@@ -0,0 +1,74 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Dict, Optional
+from datetime import datetime
+from typing_extensions import Literal
+
+from .._models import BaseModel
+
+__all__ = ["TrainingCancelResponse", "Metrics", "Output", "URLs"]
+
+
+class Metrics(BaseModel):
+ predict_time: Optional[float] = None
+ """The amount of CPU or GPU time, in seconds, that the training used while running"""
+
+
+class Output(BaseModel):
+ version: Optional[str] = None
+ """The version of the model created by the training"""
+
+ weights: Optional[str] = None
+ """The weights of the trained model"""
+
+
+class URLs(BaseModel):
+ cancel: Optional[str] = None
+ """URL to cancel the training"""
+
+ get: Optional[str] = None
+ """URL to get the training details"""
+
+
+class TrainingCancelResponse(BaseModel):
+ id: Optional[str] = None
+ """The unique ID of the training"""
+
+ completed_at: Optional[datetime] = None
+ """The time when the training completed"""
+
+ created_at: Optional[datetime] = None
+ """The time when the training was created"""
+
+ error: Optional[str] = None
+ """Error message if the training failed"""
+
+ input: Optional[Dict[str, object]] = None
+ """The input parameters used for the training"""
+
+ logs: Optional[str] = None
+ """The logs from the training process"""
+
+ metrics: Optional[Metrics] = None
+ """Metrics about the training process"""
+
+ model: Optional[str] = None
+ """The name of the model in the format owner/name"""
+
+ output: Optional[Output] = None
+ """The output of the training process"""
+
+ source: Optional[Literal["web", "api"]] = None
+ """How the training was created"""
+
+ started_at: Optional[datetime] = None
+ """The time when the training started"""
+
+ status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None
+ """The current status of the training"""
+
+ urls: Optional[URLs] = None
+ """URLs for interacting with the training"""
+
+ version: Optional[str] = None
+ """The ID of the model version used for training"""
diff --git a/src/replicate/types/training_list_response.py b/src/replicate/types/training_list_response.py
new file mode 100644
index 0000000..02adf3a
--- /dev/null
+++ b/src/replicate/types/training_list_response.py
@@ -0,0 +1,74 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Dict, Optional
+from datetime import datetime
+from typing_extensions import Literal
+
+from .._models import BaseModel
+
+__all__ = ["TrainingListResponse", "Metrics", "Output", "URLs"]
+
+
+class Metrics(BaseModel):
+ predict_time: Optional[float] = None
+ """The amount of CPU or GPU time, in seconds, that the training used while running"""
+
+
+class Output(BaseModel):
+ version: Optional[str] = None
+ """The version of the model created by the training"""
+
+ weights: Optional[str] = None
+ """The weights of the trained model"""
+
+
+class URLs(BaseModel):
+ cancel: Optional[str] = None
+ """URL to cancel the training"""
+
+ get: Optional[str] = None
+ """URL to get the training details"""
+
+
+class TrainingListResponse(BaseModel):
+ id: Optional[str] = None
+ """The unique ID of the training"""
+
+ completed_at: Optional[datetime] = None
+ """The time when the training completed"""
+
+ created_at: Optional[datetime] = None
+ """The time when the training was created"""
+
+ error: Optional[str] = None
+ """Error message if the training failed"""
+
+ input: Optional[Dict[str, object]] = None
+ """The input parameters used for the training"""
+
+ logs: Optional[str] = None
+ """The logs from the training process"""
+
+ metrics: Optional[Metrics] = None
+ """Metrics about the training process"""
+
+ model: Optional[str] = None
+ """The name of the model in the format owner/name"""
+
+ output: Optional[Output] = None
+ """The output of the training process"""
+
+ source: Optional[Literal["web", "api"]] = None
+ """How the training was created"""
+
+ started_at: Optional[datetime] = None
+ """The time when the training started"""
+
+ status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None
+ """The current status of the training"""
+
+ urls: Optional[URLs] = None
+ """URLs for interacting with the training"""
+
+ version: Optional[str] = None
+ """The ID of the model version used for training"""
diff --git a/src/replicate/types/training_retrieve_response.py b/src/replicate/types/training_retrieve_response.py
new file mode 100644
index 0000000..55b1173
--- /dev/null
+++ b/src/replicate/types/training_retrieve_response.py
@@ -0,0 +1,74 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Dict, Optional
+from datetime import datetime
+from typing_extensions import Literal
+
+from .._models import BaseModel
+
+__all__ = ["TrainingRetrieveResponse", "Metrics", "Output", "URLs"]
+
+
+class Metrics(BaseModel):
+ predict_time: Optional[float] = None
+ """The amount of CPU or GPU time, in seconds, that the training used while running"""
+
+
+class Output(BaseModel):
+ version: Optional[str] = None
+ """The version of the model created by the training"""
+
+ weights: Optional[str] = None
+ """The weights of the trained model"""
+
+
+class URLs(BaseModel):
+ cancel: Optional[str] = None
+ """URL to cancel the training"""
+
+ get: Optional[str] = None
+ """URL to get the training details"""
+
+
+class TrainingRetrieveResponse(BaseModel):
+ id: Optional[str] = None
+ """The unique ID of the training"""
+
+ completed_at: Optional[datetime] = None
+ """The time when the training completed"""
+
+ created_at: Optional[datetime] = None
+ """The time when the training was created"""
+
+ error: Optional[str] = None
+ """Error message if the training failed"""
+
+ input: Optional[Dict[str, object]] = None
+ """The input parameters used for the training"""
+
+ logs: Optional[str] = None
+ """The logs from the training process"""
+
+ metrics: Optional[Metrics] = None
+ """Metrics about the training process"""
+
+ model: Optional[str] = None
+ """The name of the model in the format owner/name"""
+
+ output: Optional[Output] = None
+ """The output of the training process"""
+
+ source: Optional[Literal["web", "api"]] = None
+ """How the training was created"""
+
+ started_at: Optional[datetime] = None
+ """The time when the training started"""
+
+ status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None
+ """The current status of the training"""
+
+ urls: Optional[URLs] = None
+ """URLs for interacting with the training"""
+
+ version: Optional[str] = None
+ """The ID of the model version used for training"""
diff --git a/tests/api_resources/models/test_versions.py b/tests/api_resources/models/test_versions.py
index 295205e..7e6d583 100644
--- a/tests/api_resources/models/test_versions.py
+++ b/tests/api_resources/models/test_versions.py
@@ -8,6 +8,8 @@
import pytest
from replicate import ReplicateClient, AsyncReplicateClient
+from tests.utils import assert_matches_type
+from replicate.types.models import VersionCreateTrainingResponse
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
@@ -205,7 +207,7 @@ def test_method_create_training(self, client: ReplicateClient) -> None:
destination="destination",
input={},
)
- assert version is None
+ assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -219,7 +221,7 @@ def test_method_create_training_with_all_params(self, client: ReplicateClient) -
webhook="webhook",
webhook_events_filter=["start"],
)
- assert version is None
+ assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -235,7 +237,7 @@ def test_raw_response_create_training(self, client: ReplicateClient) -> None:
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
version = response.parse()
- assert version is None
+ assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -251,7 +253,7 @@ def test_streaming_response_create_training(self, client: ReplicateClient) -> No
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
version = response.parse()
- assert version is None
+ assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
assert cast(Any, response.is_closed) is True
@@ -479,7 +481,7 @@ async def test_method_create_training(self, async_client: AsyncReplicateClient)
destination="destination",
input={},
)
- assert version is None
+ assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -493,7 +495,7 @@ async def test_method_create_training_with_all_params(self, async_client: AsyncR
webhook="webhook",
webhook_events_filter=["start"],
)
- assert version is None
+ assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -509,7 +511,7 @@ async def test_raw_response_create_training(self, async_client: AsyncReplicateCl
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
version = await response.parse()
- assert version is None
+ assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -525,7 +527,7 @@ async def test_streaming_response_create_training(self, async_client: AsyncRepli
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
version = await response.parse()
- assert version is None
+ assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
assert cast(Any, response.is_closed) is True
diff --git a/tests/api_resources/test_trainings.py b/tests/api_resources/test_trainings.py
index f59d874..68b1a85 100644
--- a/tests/api_resources/test_trainings.py
+++ b/tests/api_resources/test_trainings.py
@@ -8,6 +8,9 @@
import pytest
from replicate import ReplicateClient, AsyncReplicateClient
+from tests.utils import assert_matches_type
+from replicate.types import TrainingListResponse, TrainingCancelResponse, TrainingRetrieveResponse
+from replicate.pagination import SyncCursorURLPage, AsyncCursorURLPage
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
@@ -21,7 +24,7 @@ def test_method_retrieve(self, client: ReplicateClient) -> None:
training = client.trainings.retrieve(
"training_id",
)
- assert training is None
+ assert_matches_type(TrainingRetrieveResponse, training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -33,7 +36,7 @@ def test_raw_response_retrieve(self, client: ReplicateClient) -> None:
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = response.parse()
- assert training is None
+ assert_matches_type(TrainingRetrieveResponse, training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -45,7 +48,7 @@ def test_streaming_response_retrieve(self, client: ReplicateClient) -> None:
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = response.parse()
- assert training is None
+ assert_matches_type(TrainingRetrieveResponse, training, path=["response"])
assert cast(Any, response.is_closed) is True
@@ -61,7 +64,7 @@ def test_path_params_retrieve(self, client: ReplicateClient) -> None:
@parametrize
def test_method_list(self, client: ReplicateClient) -> None:
training = client.trainings.list()
- assert training is None
+ assert_matches_type(SyncCursorURLPage[TrainingListResponse], training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -71,7 +74,7 @@ def test_raw_response_list(self, client: ReplicateClient) -> None:
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = response.parse()
- assert training is None
+ assert_matches_type(SyncCursorURLPage[TrainingListResponse], training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -81,7 +84,7 @@ def test_streaming_response_list(self, client: ReplicateClient) -> None:
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = response.parse()
- assert training is None
+ assert_matches_type(SyncCursorURLPage[TrainingListResponse], training, path=["response"])
assert cast(Any, response.is_closed) is True
@@ -91,7 +94,7 @@ def test_method_cancel(self, client: ReplicateClient) -> None:
training = client.trainings.cancel(
"training_id",
)
- assert training is None
+ assert_matches_type(TrainingCancelResponse, training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -103,7 +106,7 @@ def test_raw_response_cancel(self, client: ReplicateClient) -> None:
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = response.parse()
- assert training is None
+ assert_matches_type(TrainingCancelResponse, training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -115,7 +118,7 @@ def test_streaming_response_cancel(self, client: ReplicateClient) -> None:
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = response.parse()
- assert training is None
+ assert_matches_type(TrainingCancelResponse, training, path=["response"])
assert cast(Any, response.is_closed) is True
@@ -137,7 +140,7 @@ async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None
training = await async_client.trainings.retrieve(
"training_id",
)
- assert training is None
+ assert_matches_type(TrainingRetrieveResponse, training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -149,7 +152,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = await response.parse()
- assert training is None
+ assert_matches_type(TrainingRetrieveResponse, training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -161,7 +164,7 @@ async def test_streaming_response_retrieve(self, async_client: AsyncReplicateCli
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = await response.parse()
- assert training is None
+ assert_matches_type(TrainingRetrieveResponse, training, path=["response"])
assert cast(Any, response.is_closed) is True
@@ -177,7 +180,7 @@ async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) ->
@parametrize
async def test_method_list(self, async_client: AsyncReplicateClient) -> None:
training = await async_client.trainings.list()
- assert training is None
+ assert_matches_type(AsyncCursorURLPage[TrainingListResponse], training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -187,7 +190,7 @@ async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> No
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = await response.parse()
- assert training is None
+ assert_matches_type(AsyncCursorURLPage[TrainingListResponse], training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -197,7 +200,7 @@ async def test_streaming_response_list(self, async_client: AsyncReplicateClient)
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = await response.parse()
- assert training is None
+ assert_matches_type(AsyncCursorURLPage[TrainingListResponse], training, path=["response"])
assert cast(Any, response.is_closed) is True
@@ -207,7 +210,7 @@ async def test_method_cancel(self, async_client: AsyncReplicateClient) -> None:
training = await async_client.trainings.cancel(
"training_id",
)
- assert training is None
+ assert_matches_type(TrainingCancelResponse, training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -219,7 +222,7 @@ async def test_raw_response_cancel(self, async_client: AsyncReplicateClient) ->
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = await response.parse()
- assert training is None
+ assert_matches_type(TrainingCancelResponse, training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -231,7 +234,7 @@ async def test_streaming_response_cancel(self, async_client: AsyncReplicateClien
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = await response.parse()
- assert training is None
+ assert_matches_type(TrainingCancelResponse, training, path=["response"])
assert cast(Any, response.is_closed) is True
From f1e4d140104ff317b94cb2dd88ec850a9b8bce54 Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Thu, 17 Apr 2025 03:30:58 +0000
Subject: [PATCH 03/12] chore(internal): bump pyright version
---
pyproject.toml | 2 +-
requirements-dev.lock | 2 +-
src/replicate/_base_client.py | 6 +++++-
src/replicate/_models.py | 1 -
src/replicate/_utils/_typing.py | 2 +-
tests/conftest.py | 2 +-
tests/test_models.py | 2 +-
7 files changed, 10 insertions(+), 7 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index d9681a8..8eb0e54 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,7 @@ Repository = "https://github.com/replicate/replicate-python-stainless"
managed = true
# version pins are in requirements-dev.lock
dev-dependencies = [
- "pyright>=1.1.359",
+ "pyright==1.1.399",
"mypy",
"respx",
"pytest",
diff --git a/requirements-dev.lock b/requirements-dev.lock
index e9ea098..86eea12 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -69,7 +69,7 @@ pydantic-core==2.27.1
# via pydantic
pygments==2.18.0
# via rich
-pyright==1.1.392.post0
+pyright==1.1.399
pytest==8.3.3
# via pytest-asyncio
pytest-asyncio==0.24.0
diff --git a/src/replicate/_base_client.py b/src/replicate/_base_client.py
index d55fecf..3d26f28 100644
--- a/src/replicate/_base_client.py
+++ b/src/replicate/_base_client.py
@@ -98,7 +98,11 @@
_AsyncStreamT = TypeVar("_AsyncStreamT", bound=AsyncStream[Any])
if TYPE_CHECKING:
- from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
+ from httpx._config import (
+ DEFAULT_TIMEOUT_CONFIG, # pyright: ignore[reportPrivateImportUsage]
+ )
+
+ HTTPX_DEFAULT_TIMEOUT = DEFAULT_TIMEOUT_CONFIG
else:
try:
from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
diff --git a/src/replicate/_models.py b/src/replicate/_models.py
index 3493571..58b9263 100644
--- a/src/replicate/_models.py
+++ b/src/replicate/_models.py
@@ -19,7 +19,6 @@
)
import pydantic
-import pydantic.generics
from pydantic.fields import FieldInfo
from ._types import (
diff --git a/src/replicate/_utils/_typing.py b/src/replicate/_utils/_typing.py
index 1958820..1bac954 100644
--- a/src/replicate/_utils/_typing.py
+++ b/src/replicate/_utils/_typing.py
@@ -110,7 +110,7 @@ class MyResponse(Foo[_T]):
```
"""
cls = cast(object, get_origin(typ) or typ)
- if cls in generic_bases:
+ if cls in generic_bases: # pyright: ignore[reportUnnecessaryContains]
# we're given the class directly
return extract_type_arg(typ, index)
diff --git a/tests/conftest.py b/tests/conftest.py
index 97007f1..7d1538b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -10,7 +10,7 @@
from replicate import ReplicateClient, AsyncReplicateClient
if TYPE_CHECKING:
- from _pytest.fixtures import FixtureRequest
+ from _pytest.fixtures import FixtureRequest # pyright: ignore[reportPrivateImportUsage]
pytest.register_assert_rewrite("tests.utils")
diff --git a/tests/test_models.py b/tests/test_models.py
index ee374fa..25f3d43 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -832,7 +832,7 @@ class B(BaseModel):
@pytest.mark.skipif(not PYDANTIC_V2, reason="TypeAliasType is not supported in Pydantic v1")
def test_type_alias_type() -> None:
- Alias = TypeAliasType("Alias", str)
+ Alias = TypeAliasType("Alias", str) # pyright: ignore
class Model(BaseModel):
alias: Alias
From c1d6ed59ed0f06012922ec6d0bae376852523d81 Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Thu, 17 Apr 2025 03:31:38 +0000
Subject: [PATCH 04/12] chore(internal): base client updates
---
src/replicate/_base_client.py | 25 +++++++++++++++++++++++++
1 file changed, 25 insertions(+)
diff --git a/src/replicate/_base_client.py b/src/replicate/_base_client.py
index 3d26f28..fb06997 100644
--- a/src/replicate/_base_client.py
+++ b/src/replicate/_base_client.py
@@ -119,6 +119,7 @@ class PageInfo:
url: URL | NotGiven
params: Query | NotGiven
+ json: Body | NotGiven
@overload
def __init__(
@@ -134,19 +135,30 @@ def __init__(
params: Query,
) -> None: ...
+ @overload
+ def __init__(
+ self,
+ *,
+ json: Body,
+ ) -> None: ...
+
def __init__(
self,
*,
url: URL | NotGiven = NOT_GIVEN,
+ json: Body | NotGiven = NOT_GIVEN,
params: Query | NotGiven = NOT_GIVEN,
) -> None:
self.url = url
+ self.json = json
self.params = params
@override
def __repr__(self) -> str:
if self.url:
return f"{self.__class__.__name__}(url={self.url})"
+ if self.json:
+ return f"{self.__class__.__name__}(json={self.json})"
return f"{self.__class__.__name__}(params={self.params})"
@@ -195,6 +207,19 @@ def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:
options.url = str(url)
return options
+ if not isinstance(info.json, NotGiven):
+ if not is_mapping(info.json):
+ raise TypeError("Pagination is only supported with mappings")
+
+ if not options.json_data:
+ options.json_data = {**info.json}
+ else:
+ if not is_mapping(options.json_data):
+ raise TypeError("Pagination is only supported with mappings")
+
+ options.json_data = {**options.json_data, **info.json}
+ return options
+
raise ValueError("Unexpected PageInfo state")
From a812f1b8f719c301ebf98a9d919a382e411ef247 Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Thu, 17 Apr 2025 11:15:40 +0000
Subject: [PATCH 05/12] replace `retrieve` with `get` for method names
---
.stats.yml | 2 +-
api.md | 14 +-
.../resources/deployments/deployments.py | 310 ++++++------
src/replicate/resources/models/models.py | 460 +++++++++---------
src/replicate/resources/models/versions.py | 260 +++++-----
src/replicate/resources/predictions.py | 432 ++++++++--------
src/replicate/resources/trainings.py | 262 +++++-----
src/replicate/types/__init__.py | 4 +-
...response.py => deployment_get_response.py} | 4 +-
...e_response.py => training_get_response.py} | 4 +-
tests/api_resources/models/test_versions.py | 160 +++---
tests/api_resources/test_deployments.py | 210 ++++----
tests/api_resources/test_models.py | 208 ++++----
tests/api_resources/test_predictions.py | 168 +++----
tests/api_resources/test_trainings.py | 120 ++---
15 files changed, 1309 insertions(+), 1309 deletions(-)
rename src/replicate/types/{deployment_retrieve_response.py => deployment_get_response.py} (91%)
rename src/replicate/types/{training_retrieve_response.py => training_get_response.py} (94%)
diff --git a/.stats.yml b/.stats.yml
index 1448ef2..190fe36 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,4 +1,4 @@
configured_endpoints: 27
openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-2788217b7ad7d61d1a77800bc5ff12a6810f1692d4d770b72fa8f898c6a055ab.yml
openapi_spec_hash: 4423bf747e228484547b441468a9f156
-config_hash: 64fd304bf0ff077a74053ce79a255106
+config_hash: d1d273c0d97d034d24c7eac8ef51d2ac
diff --git a/api.md b/api.md
index 509868d..e972297 100644
--- a/api.md
+++ b/api.md
@@ -11,19 +11,19 @@ Types:
```python
from replicate.types import (
DeploymentCreateResponse,
- DeploymentRetrieveResponse,
DeploymentUpdateResponse,
DeploymentListResponse,
+ DeploymentGetResponse,
)
```
Methods:
- client.deployments.create(\*\*params) -> DeploymentCreateResponse
-- client.deployments.retrieve(deployment_name, \*, deployment_owner) -> DeploymentRetrieveResponse
- client.deployments.update(deployment_name, \*, deployment_owner, \*\*params) -> DeploymentUpdateResponse
- client.deployments.list() -> SyncCursorURLPage[DeploymentListResponse]
- client.deployments.delete(deployment_name, \*, deployment_owner) -> None
+- client.deployments.get(deployment_name, \*, deployment_owner) -> DeploymentGetResponse
- client.deployments.list_em_all() -> None
## Predictions
@@ -68,10 +68,10 @@ from replicate.types import ModelListResponse
Methods:
- client.models.create(\*\*params) -> None
-- client.models.retrieve(model_name, \*, model_owner) -> None
- client.models.list() -> SyncCursorURLPage[ModelListResponse]
- client.models.delete(model_name, \*, model_owner) -> None
- client.models.create_prediction(model_name, \*, model_owner, \*\*params) -> Prediction
+- client.models.get(model_name, \*, model_owner) -> None
## Versions
@@ -83,10 +83,10 @@ from replicate.types.models import VersionCreateTrainingResponse
Methods:
-- client.models.versions.retrieve(version_id, \*, model_owner, model_name) -> None
- client.models.versions.list(model_name, \*, model_owner) -> None
- client.models.versions.delete(version_id, \*, model_owner, model_name) -> None
- client.models.versions.create_training(version_id, \*, model_owner, model_name, \*\*params) -> VersionCreateTrainingResponse
+- client.models.versions.get(version_id, \*, model_owner, model_name) -> None
# Predictions
@@ -99,23 +99,23 @@ from replicate.types import Prediction, PredictionOutput, PredictionRequest
Methods:
- client.predictions.create(\*\*params) -> Prediction
-- client.predictions.retrieve(prediction_id) -> Prediction
- client.predictions.list(\*\*params) -> SyncCursorURLPageWithCreatedFilters[Prediction]
- client.predictions.cancel(prediction_id) -> None
+- client.predictions.get(prediction_id) -> Prediction
# Trainings
Types:
```python
-from replicate.types import TrainingRetrieveResponse, TrainingListResponse, TrainingCancelResponse
+from replicate.types import TrainingListResponse, TrainingCancelResponse, TrainingGetResponse
```
Methods:
-- client.trainings.retrieve(training_id) -> TrainingRetrieveResponse
- client.trainings.list() -> SyncCursorURLPage[TrainingListResponse]
- client.trainings.cancel(training_id) -> TrainingCancelResponse
+- client.trainings.get(training_id) -> TrainingGetResponse
# Webhooks
diff --git a/src/replicate/resources/deployments/deployments.py b/src/replicate/resources/deployments/deployments.py
index 0211b67..7e2fac7 100644
--- a/src/replicate/resources/deployments/deployments.py
+++ b/src/replicate/resources/deployments/deployments.py
@@ -28,10 +28,10 @@
)
from ...pagination import SyncCursorURLPage, AsyncCursorURLPage
from ..._base_client import AsyncPaginator, make_request_options
+from ...types.deployment_get_response import DeploymentGetResponse
from ...types.deployment_list_response import DeploymentListResponse
from ...types.deployment_create_response import DeploymentCreateResponse
from ...types.deployment_update_response import DeploymentUpdateResponse
-from ...types.deployment_retrieve_response import DeploymentRetrieveResponse
__all__ = ["DeploymentsResource", "AsyncDeploymentsResource"]
@@ -165,77 +165,6 @@ def create(
cast_to=DeploymentCreateResponse,
)
- def retrieve(
- self,
- deployment_name: str,
- *,
- deployment_owner: str,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> DeploymentRetrieveResponse:
- """
- Get information about a deployment by name including the current release.
-
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/deployments/replicate/my-app-image-generator
- ```
-
- The response will be a JSON object describing the deployment:
-
- ```json
- {
- "owner": "acme",
- "name": "my-app-image-generator",
- "current_release": {
- "number": 1,
- "model": "stability-ai/sdxl",
- "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
- "created_at": "2024-02-15T16:32:57.018467Z",
- "created_by": {
- "type": "organization",
- "username": "acme",
- "name": "Acme Corp, Inc.",
- "avatar_url": "https://cdn.replicate.com/avatars/acme.png",
- "github_url": "https://github.com/acme"
- },
- "configuration": {
- "hardware": "gpu-t4",
- "min_instances": 1,
- "max_instances": 5
- }
- }
- }
- ```
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not deployment_owner:
- raise ValueError(f"Expected a non-empty value for `deployment_owner` but received {deployment_owner!r}")
- if not deployment_name:
- raise ValueError(f"Expected a non-empty value for `deployment_name` but received {deployment_name!r}")
- return self._get(
- f"/deployments/{deployment_owner}/{deployment_name}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=DeploymentRetrieveResponse,
- )
-
def update(
self,
deployment_name: str,
@@ -454,6 +383,77 @@ def delete(
cast_to=NoneType,
)
+ def get(
+ self,
+ deployment_name: str,
+ *,
+ deployment_owner: str,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> DeploymentGetResponse:
+ """
+ Get information about a deployment by name including the current release.
+
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/deployments/replicate/my-app-image-generator
+ ```
+
+ The response will be a JSON object describing the deployment:
+
+ ```json
+ {
+ "owner": "acme",
+ "name": "my-app-image-generator",
+ "current_release": {
+ "number": 1,
+ "model": "stability-ai/sdxl",
+ "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
+ "created_at": "2024-02-15T16:32:57.018467Z",
+ "created_by": {
+ "type": "organization",
+ "username": "acme",
+ "name": "Acme Corp, Inc.",
+ "avatar_url": "https://cdn.replicate.com/avatars/acme.png",
+ "github_url": "https://github.com/acme"
+ },
+ "configuration": {
+ "hardware": "gpu-t4",
+ "min_instances": 1,
+ "max_instances": 5
+ }
+ }
+ }
+ ```
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not deployment_owner:
+ raise ValueError(f"Expected a non-empty value for `deployment_owner` but received {deployment_owner!r}")
+ if not deployment_name:
+ raise ValueError(f"Expected a non-empty value for `deployment_name` but received {deployment_name!r}")
+ return self._get(
+ f"/deployments/{deployment_owner}/{deployment_name}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=DeploymentGetResponse,
+ )
+
def list_em_all(
self,
*,
@@ -628,77 +628,6 @@ async def create(
cast_to=DeploymentCreateResponse,
)
- async def retrieve(
- self,
- deployment_name: str,
- *,
- deployment_owner: str,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> DeploymentRetrieveResponse:
- """
- Get information about a deployment by name including the current release.
-
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/deployments/replicate/my-app-image-generator
- ```
-
- The response will be a JSON object describing the deployment:
-
- ```json
- {
- "owner": "acme",
- "name": "my-app-image-generator",
- "current_release": {
- "number": 1,
- "model": "stability-ai/sdxl",
- "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
- "created_at": "2024-02-15T16:32:57.018467Z",
- "created_by": {
- "type": "organization",
- "username": "acme",
- "name": "Acme Corp, Inc.",
- "avatar_url": "https://cdn.replicate.com/avatars/acme.png",
- "github_url": "https://github.com/acme"
- },
- "configuration": {
- "hardware": "gpu-t4",
- "min_instances": 1,
- "max_instances": 5
- }
- }
- }
- ```
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not deployment_owner:
- raise ValueError(f"Expected a non-empty value for `deployment_owner` but received {deployment_owner!r}")
- if not deployment_name:
- raise ValueError(f"Expected a non-empty value for `deployment_name` but received {deployment_name!r}")
- return await self._get(
- f"/deployments/{deployment_owner}/{deployment_name}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=DeploymentRetrieveResponse,
- )
-
async def update(
self,
deployment_name: str,
@@ -917,6 +846,77 @@ async def delete(
cast_to=NoneType,
)
+ async def get(
+ self,
+ deployment_name: str,
+ *,
+ deployment_owner: str,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> DeploymentGetResponse:
+ """
+ Get information about a deployment by name including the current release.
+
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/deployments/replicate/my-app-image-generator
+ ```
+
+ The response will be a JSON object describing the deployment:
+
+ ```json
+ {
+ "owner": "acme",
+ "name": "my-app-image-generator",
+ "current_release": {
+ "number": 1,
+ "model": "stability-ai/sdxl",
+ "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
+ "created_at": "2024-02-15T16:32:57.018467Z",
+ "created_by": {
+ "type": "organization",
+ "username": "acme",
+ "name": "Acme Corp, Inc.",
+ "avatar_url": "https://cdn.replicate.com/avatars/acme.png",
+ "github_url": "https://github.com/acme"
+ },
+ "configuration": {
+ "hardware": "gpu-t4",
+ "min_instances": 1,
+ "max_instances": 5
+ }
+ }
+ }
+ ```
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not deployment_owner:
+ raise ValueError(f"Expected a non-empty value for `deployment_owner` but received {deployment_owner!r}")
+ if not deployment_name:
+ raise ValueError(f"Expected a non-empty value for `deployment_name` but received {deployment_name!r}")
+ return await self._get(
+ f"/deployments/{deployment_owner}/{deployment_name}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=DeploymentGetResponse,
+ )
+
async def list_em_all(
self,
*,
@@ -969,9 +969,6 @@ def __init__(self, deployments: DeploymentsResource) -> None:
self.create = to_raw_response_wrapper(
deployments.create,
)
- self.retrieve = to_raw_response_wrapper(
- deployments.retrieve,
- )
self.update = to_raw_response_wrapper(
deployments.update,
)
@@ -981,6 +978,9 @@ def __init__(self, deployments: DeploymentsResource) -> None:
self.delete = to_raw_response_wrapper(
deployments.delete,
)
+ self.get = to_raw_response_wrapper(
+ deployments.get,
+ )
self.list_em_all = to_raw_response_wrapper(
deployments.list_em_all,
)
@@ -997,9 +997,6 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None:
self.create = async_to_raw_response_wrapper(
deployments.create,
)
- self.retrieve = async_to_raw_response_wrapper(
- deployments.retrieve,
- )
self.update = async_to_raw_response_wrapper(
deployments.update,
)
@@ -1009,6 +1006,9 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None:
self.delete = async_to_raw_response_wrapper(
deployments.delete,
)
+ self.get = async_to_raw_response_wrapper(
+ deployments.get,
+ )
self.list_em_all = async_to_raw_response_wrapper(
deployments.list_em_all,
)
@@ -1025,9 +1025,6 @@ def __init__(self, deployments: DeploymentsResource) -> None:
self.create = to_streamed_response_wrapper(
deployments.create,
)
- self.retrieve = to_streamed_response_wrapper(
- deployments.retrieve,
- )
self.update = to_streamed_response_wrapper(
deployments.update,
)
@@ -1037,6 +1034,9 @@ def __init__(self, deployments: DeploymentsResource) -> None:
self.delete = to_streamed_response_wrapper(
deployments.delete,
)
+ self.get = to_streamed_response_wrapper(
+ deployments.get,
+ )
self.list_em_all = to_streamed_response_wrapper(
deployments.list_em_all,
)
@@ -1053,9 +1053,6 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None:
self.create = async_to_streamed_response_wrapper(
deployments.create,
)
- self.retrieve = async_to_streamed_response_wrapper(
- deployments.retrieve,
- )
self.update = async_to_streamed_response_wrapper(
deployments.update,
)
@@ -1065,6 +1062,9 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None:
self.delete = async_to_streamed_response_wrapper(
deployments.delete,
)
+ self.get = async_to_streamed_response_wrapper(
+ deployments.get,
+ )
self.list_em_all = async_to_streamed_response_wrapper(
deployments.list_em_all,
)
diff --git a/src/replicate/resources/models/models.py b/src/replicate/resources/models/models.py
index 77145ae..6d8c6eb 100644
--- a/src/replicate/resources/models/models.py
+++ b/src/replicate/resources/models/models.py
@@ -174,115 +174,6 @@ def create(
cast_to=NoneType,
)
- def retrieve(
- self,
- model_name: str,
- *,
- model_owner: str,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
- """
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/models/replicate/hello-world
- ```
-
- The response will be a model object in the following format:
-
- ```json
- {
- "url": "https://replicate.com/replicate/hello-world",
- "owner": "replicate",
- "name": "hello-world",
- "description": "A tiny model that says hello",
- "visibility": "public",
- "github_url": "https://github.com/replicate/cog-examples",
- "paper_url": null,
- "license_url": null,
- "run_count": 5681081,
- "cover_image_url": "...",
- "default_example": {...},
- "latest_version": {...},
- }
- ```
-
- The model object includes the
- [input and output schema](https://replicate.com/docs/reference/openapi#model-schemas)
- for the latest version of the model.
-
- Here's an example showing how to fetch the model with cURL and display its input
- schema with [jq](https://stedolan.github.io/jq/):
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/models/replicate/hello-world \\
- | jq ".latest_version.openapi_schema.components.schemas.Input"
- ```
-
- This will return the following JSON object:
-
- ```json
- {
- "type": "object",
- "title": "Input",
- "required": ["text"],
- "properties": {
- "text": {
- "type": "string",
- "title": "Text",
- "x-order": 0,
- "description": "Text to prefix with 'hello '"
- }
- }
- }
- ```
-
- The `cover_image_url` string is an HTTPS URL for an image file. This can be:
-
- - An image uploaded by the model author.
- - The output file of the example prediction, if the model author has not set a
- cover image.
- - The input file of the example prediction, if the model author has not set a
- cover image and the example prediction has no output file.
- - A generic fallback image.
-
- The `default_example` object is a [prediction](#predictions.get) created with
- this model.
-
- The `latest_version` object is the model's most recently pushed
- [version](#models.versions.get).
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not model_owner:
- raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
- if not model_name:
- raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return self._get(
- f"/models/{model_owner}/{model_name}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=NoneType,
- )
-
def list(
self,
*,
@@ -514,6 +405,115 @@ def create_prediction(
cast_to=Prediction,
)
+ def get(
+ self,
+ model_name: str,
+ *,
+ model_owner: str,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> None:
+ """
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/models/replicate/hello-world
+ ```
+
+ The response will be a model object in the following format:
+
+ ```json
+ {
+ "url": "https://replicate.com/replicate/hello-world",
+ "owner": "replicate",
+ "name": "hello-world",
+ "description": "A tiny model that says hello",
+ "visibility": "public",
+ "github_url": "https://github.com/replicate/cog-examples",
+ "paper_url": null,
+ "license_url": null,
+ "run_count": 5681081,
+ "cover_image_url": "...",
+ "default_example": {...},
+ "latest_version": {...},
+ }
+ ```
+
+ The model object includes the
+ [input and output schema](https://replicate.com/docs/reference/openapi#model-schemas)
+ for the latest version of the model.
+
+ Here's an example showing how to fetch the model with cURL and display its input
+ schema with [jq](https://stedolan.github.io/jq/):
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/models/replicate/hello-world \\
+ | jq ".latest_version.openapi_schema.components.schemas.Input"
+ ```
+
+ This will return the following JSON object:
+
+ ```json
+ {
+ "type": "object",
+ "title": "Input",
+ "required": ["text"],
+ "properties": {
+ "text": {
+ "type": "string",
+ "title": "Text",
+ "x-order": 0,
+ "description": "Text to prefix with 'hello '"
+ }
+ }
+ }
+ ```
+
+ The `cover_image_url` string is an HTTPS URL for an image file. This can be:
+
+ - An image uploaded by the model author.
+ - The output file of the example prediction, if the model author has not set a
+ cover image.
+ - The input file of the example prediction, if the model author has not set a
+ cover image and the example prediction has no output file.
+ - A generic fallback image.
+
+ The `default_example` object is a [prediction](#predictions.get) created with
+ this model.
+
+ The `latest_version` object is the model's most recently pushed
+ [version](#models.versions.get).
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not model_owner:
+ raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
+ if not model_name:
+ raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
+ extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ return self._get(
+ f"/models/{model_owner}/{model_name}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=NoneType,
+ )
+
class AsyncModelsResource(AsyncAPIResource):
@cached_property
@@ -651,115 +651,6 @@ async def create(
cast_to=NoneType,
)
- async def retrieve(
- self,
- model_name: str,
- *,
- model_owner: str,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
- """
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/models/replicate/hello-world
- ```
-
- The response will be a model object in the following format:
-
- ```json
- {
- "url": "https://replicate.com/replicate/hello-world",
- "owner": "replicate",
- "name": "hello-world",
- "description": "A tiny model that says hello",
- "visibility": "public",
- "github_url": "https://github.com/replicate/cog-examples",
- "paper_url": null,
- "license_url": null,
- "run_count": 5681081,
- "cover_image_url": "...",
- "default_example": {...},
- "latest_version": {...},
- }
- ```
-
- The model object includes the
- [input and output schema](https://replicate.com/docs/reference/openapi#model-schemas)
- for the latest version of the model.
-
- Here's an example showing how to fetch the model with cURL and display its input
- schema with [jq](https://stedolan.github.io/jq/):
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/models/replicate/hello-world \\
- | jq ".latest_version.openapi_schema.components.schemas.Input"
- ```
-
- This will return the following JSON object:
-
- ```json
- {
- "type": "object",
- "title": "Input",
- "required": ["text"],
- "properties": {
- "text": {
- "type": "string",
- "title": "Text",
- "x-order": 0,
- "description": "Text to prefix with 'hello '"
- }
- }
- }
- ```
-
- The `cover_image_url` string is an HTTPS URL for an image file. This can be:
-
- - An image uploaded by the model author.
- - The output file of the example prediction, if the model author has not set a
- cover image.
- - The input file of the example prediction, if the model author has not set a
- cover image and the example prediction has no output file.
- - A generic fallback image.
-
- The `default_example` object is a [prediction](#predictions.get) created with
- this model.
-
- The `latest_version` object is the model's most recently pushed
- [version](#models.versions.get).
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not model_owner:
- raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
- if not model_name:
- raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return await self._get(
- f"/models/{model_owner}/{model_name}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=NoneType,
- )
-
def list(
self,
*,
@@ -991,6 +882,115 @@ async def create_prediction(
cast_to=Prediction,
)
+ async def get(
+ self,
+ model_name: str,
+ *,
+ model_owner: str,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> None:
+ """
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/models/replicate/hello-world
+ ```
+
+ The response will be a model object in the following format:
+
+ ```json
+ {
+ "url": "https://replicate.com/replicate/hello-world",
+ "owner": "replicate",
+ "name": "hello-world",
+ "description": "A tiny model that says hello",
+ "visibility": "public",
+ "github_url": "https://github.com/replicate/cog-examples",
+ "paper_url": null,
+ "license_url": null,
+ "run_count": 5681081,
+ "cover_image_url": "...",
+ "default_example": {...},
+ "latest_version": {...},
+ }
+ ```
+
+ The model object includes the
+ [input and output schema](https://replicate.com/docs/reference/openapi#model-schemas)
+ for the latest version of the model.
+
+ Here's an example showing how to fetch the model with cURL and display its input
+ schema with [jq](https://stedolan.github.io/jq/):
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/models/replicate/hello-world \\
+ | jq ".latest_version.openapi_schema.components.schemas.Input"
+ ```
+
+ This will return the following JSON object:
+
+ ```json
+ {
+ "type": "object",
+ "title": "Input",
+ "required": ["text"],
+ "properties": {
+ "text": {
+ "type": "string",
+ "title": "Text",
+ "x-order": 0,
+ "description": "Text to prefix with 'hello '"
+ }
+ }
+ }
+ ```
+
+ The `cover_image_url` string is an HTTPS URL for an image file. This can be:
+
+ - An image uploaded by the model author.
+ - The output file of the example prediction, if the model author has not set a
+ cover image.
+ - The input file of the example prediction, if the model author has not set a
+ cover image and the example prediction has no output file.
+ - A generic fallback image.
+
+ The `default_example` object is a [prediction](#predictions.get) created with
+ this model.
+
+ The `latest_version` object is the model's most recently pushed
+ [version](#models.versions.get).
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not model_owner:
+ raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
+ if not model_name:
+ raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
+ extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ return await self._get(
+ f"/models/{model_owner}/{model_name}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=NoneType,
+ )
+
class ModelsResourceWithRawResponse:
def __init__(self, models: ModelsResource) -> None:
@@ -999,9 +999,6 @@ def __init__(self, models: ModelsResource) -> None:
self.create = to_raw_response_wrapper(
models.create,
)
- self.retrieve = to_raw_response_wrapper(
- models.retrieve,
- )
self.list = to_raw_response_wrapper(
models.list,
)
@@ -1011,6 +1008,9 @@ def __init__(self, models: ModelsResource) -> None:
self.create_prediction = to_raw_response_wrapper(
models.create_prediction,
)
+ self.get = to_raw_response_wrapper(
+ models.get,
+ )
@cached_property
def versions(self) -> VersionsResourceWithRawResponse:
@@ -1024,9 +1024,6 @@ def __init__(self, models: AsyncModelsResource) -> None:
self.create = async_to_raw_response_wrapper(
models.create,
)
- self.retrieve = async_to_raw_response_wrapper(
- models.retrieve,
- )
self.list = async_to_raw_response_wrapper(
models.list,
)
@@ -1036,6 +1033,9 @@ def __init__(self, models: AsyncModelsResource) -> None:
self.create_prediction = async_to_raw_response_wrapper(
models.create_prediction,
)
+ self.get = async_to_raw_response_wrapper(
+ models.get,
+ )
@cached_property
def versions(self) -> AsyncVersionsResourceWithRawResponse:
@@ -1049,9 +1049,6 @@ def __init__(self, models: ModelsResource) -> None:
self.create = to_streamed_response_wrapper(
models.create,
)
- self.retrieve = to_streamed_response_wrapper(
- models.retrieve,
- )
self.list = to_streamed_response_wrapper(
models.list,
)
@@ -1061,6 +1058,9 @@ def __init__(self, models: ModelsResource) -> None:
self.create_prediction = to_streamed_response_wrapper(
models.create_prediction,
)
+ self.get = to_streamed_response_wrapper(
+ models.get,
+ )
@cached_property
def versions(self) -> VersionsResourceWithStreamingResponse:
@@ -1074,9 +1074,6 @@ def __init__(self, models: AsyncModelsResource) -> None:
self.create = async_to_streamed_response_wrapper(
models.create,
)
- self.retrieve = async_to_streamed_response_wrapper(
- models.retrieve,
- )
self.list = async_to_streamed_response_wrapper(
models.list,
)
@@ -1086,6 +1083,9 @@ def __init__(self, models: AsyncModelsResource) -> None:
self.create_prediction = async_to_streamed_response_wrapper(
models.create_prediction,
)
+ self.get = async_to_streamed_response_wrapper(
+ models.get,
+ )
@cached_property
def versions(self) -> AsyncVersionsResourceWithStreamingResponse:
diff --git a/src/replicate/resources/models/versions.py b/src/replicate/resources/models/versions.py
index 203bb5d..2065fe0 100644
--- a/src/replicate/resources/models/versions.py
+++ b/src/replicate/resources/models/versions.py
@@ -47,101 +47,6 @@ def with_streaming_response(self) -> VersionsResourceWithStreamingResponse:
"""
return VersionsResourceWithStreamingResponse(self)
- def retrieve(
- self,
- version_id: str,
- *,
- model_owner: str,
- model_name: str,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
- """
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/models/replicate/hello-world/versions/5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa
- ```
-
- The response will be the version object:
-
- ```json
- {
- "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
- "created_at": "2022-04-26T19:29:04.418669Z",
- "cog_version": "0.3.0",
- "openapi_schema": {...}
- }
- ```
-
- Every model describes its inputs and outputs with
- [OpenAPI Schema Objects](https://spec.openapis.org/oas/latest.html#schemaObject)
- in the `openapi_schema` property.
-
- The `openapi_schema.components.schemas.Input` property for the
- [replicate/hello-world](https://replicate.com/replicate/hello-world) model looks
- like this:
-
- ```json
- {
- "type": "object",
- "title": "Input",
- "required": ["text"],
- "properties": {
- "text": {
- "x-order": 0,
- "type": "string",
- "title": "Text",
- "description": "Text to prefix with 'hello '"
- }
- }
- }
- ```
-
- The `openapi_schema.components.schemas.Output` property for the
- [replicate/hello-world](https://replicate.com/replicate/hello-world) model looks
- like this:
-
- ```json
- {
- "type": "string",
- "title": "Output"
- }
- ```
-
- For more details, see the docs on
- [Cog's supported input and output types](https://github.com/replicate/cog/blob/75b7802219e7cd4cee845e34c4c22139558615d4/docs/python.md#input-and-output-types)
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not model_owner:
- raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
- if not model_name:
- raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
- if not version_id:
- raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return self._get(
- f"/models/{model_owner}/{model_name}/versions/{version_id}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=NoneType,
- )
-
def list(
self,
model_name: str,
@@ -416,28 +321,7 @@ def create_training(
cast_to=VersionCreateTrainingResponse,
)
-
-class AsyncVersionsResource(AsyncAPIResource):
- @cached_property
- def with_raw_response(self) -> AsyncVersionsResourceWithRawResponse:
- """
- This property can be used as a prefix for any HTTP method call to return
- the raw response object instead of the parsed content.
-
- For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers
- """
- return AsyncVersionsResourceWithRawResponse(self)
-
- @cached_property
- def with_streaming_response(self) -> AsyncVersionsResourceWithStreamingResponse:
- """
- An alternative to `.with_raw_response` that doesn't eagerly read the response body.
-
- For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response
- """
- return AsyncVersionsResourceWithStreamingResponse(self)
-
- async def retrieve(
+ def get(
self,
version_id: str,
*,
@@ -524,7 +408,7 @@ async def retrieve(
if not version_id:
raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}")
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return await self._get(
+ return self._get(
f"/models/{model_owner}/{model_name}/versions/{version_id}",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -532,6 +416,27 @@ async def retrieve(
cast_to=NoneType,
)
+
+class AsyncVersionsResource(AsyncAPIResource):
+ @cached_property
+ def with_raw_response(self) -> AsyncVersionsResourceWithRawResponse:
+ """
+ This property can be used as a prefix for any HTTP method call to return
+ the raw response object instead of the parsed content.
+
+ For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers
+ """
+ return AsyncVersionsResourceWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> AsyncVersionsResourceWithStreamingResponse:
+ """
+ An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+ For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response
+ """
+ return AsyncVersionsResourceWithStreamingResponse(self)
+
async def list(
self,
model_name: str,
@@ -806,14 +711,106 @@ async def create_training(
cast_to=VersionCreateTrainingResponse,
)
+ async def get(
+ self,
+ version_id: str,
+ *,
+ model_owner: str,
+ model_name: str,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> None:
+ """
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/models/replicate/hello-world/versions/5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa
+ ```
+
+ The response will be the version object:
+
+ ```json
+ {
+ "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
+ "created_at": "2022-04-26T19:29:04.418669Z",
+ "cog_version": "0.3.0",
+ "openapi_schema": {...}
+ }
+ ```
+
+ Every model describes its inputs and outputs with
+ [OpenAPI Schema Objects](https://spec.openapis.org/oas/latest.html#schemaObject)
+ in the `openapi_schema` property.
+
+ The `openapi_schema.components.schemas.Input` property for the
+ [replicate/hello-world](https://replicate.com/replicate/hello-world) model looks
+ like this:
+
+ ```json
+ {
+ "type": "object",
+ "title": "Input",
+ "required": ["text"],
+ "properties": {
+ "text": {
+ "x-order": 0,
+ "type": "string",
+ "title": "Text",
+ "description": "Text to prefix with 'hello '"
+ }
+ }
+ }
+ ```
+
+ The `openapi_schema.components.schemas.Output` property for the
+ [replicate/hello-world](https://replicate.com/replicate/hello-world) model looks
+ like this:
+
+ ```json
+ {
+ "type": "string",
+ "title": "Output"
+ }
+ ```
+
+ For more details, see the docs on
+ [Cog's supported input and output types](https://github.com/replicate/cog/blob/75b7802219e7cd4cee845e34c4c22139558615d4/docs/python.md#input-and-output-types)
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not model_owner:
+ raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
+ if not model_name:
+ raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
+ if not version_id:
+ raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}")
+ extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ return await self._get(
+ f"/models/{model_owner}/{model_name}/versions/{version_id}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=NoneType,
+ )
+
class VersionsResourceWithRawResponse:
def __init__(self, versions: VersionsResource) -> None:
self._versions = versions
- self.retrieve = to_raw_response_wrapper(
- versions.retrieve,
- )
self.list = to_raw_response_wrapper(
versions.list,
)
@@ -823,15 +820,15 @@ def __init__(self, versions: VersionsResource) -> None:
self.create_training = to_raw_response_wrapper(
versions.create_training,
)
+ self.get = to_raw_response_wrapper(
+ versions.get,
+ )
class AsyncVersionsResourceWithRawResponse:
def __init__(self, versions: AsyncVersionsResource) -> None:
self._versions = versions
- self.retrieve = async_to_raw_response_wrapper(
- versions.retrieve,
- )
self.list = async_to_raw_response_wrapper(
versions.list,
)
@@ -841,15 +838,15 @@ def __init__(self, versions: AsyncVersionsResource) -> None:
self.create_training = async_to_raw_response_wrapper(
versions.create_training,
)
+ self.get = async_to_raw_response_wrapper(
+ versions.get,
+ )
class VersionsResourceWithStreamingResponse:
def __init__(self, versions: VersionsResource) -> None:
self._versions = versions
- self.retrieve = to_streamed_response_wrapper(
- versions.retrieve,
- )
self.list = to_streamed_response_wrapper(
versions.list,
)
@@ -859,15 +856,15 @@ def __init__(self, versions: VersionsResource) -> None:
self.create_training = to_streamed_response_wrapper(
versions.create_training,
)
+ self.get = to_streamed_response_wrapper(
+ versions.get,
+ )
class AsyncVersionsResourceWithStreamingResponse:
def __init__(self, versions: AsyncVersionsResource) -> None:
self._versions = versions
- self.retrieve = async_to_streamed_response_wrapper(
- versions.retrieve,
- )
self.list = async_to_streamed_response_wrapper(
versions.list,
)
@@ -877,3 +874,6 @@ def __init__(self, versions: AsyncVersionsResource) -> None:
self.create_training = async_to_streamed_response_wrapper(
versions.create_training,
)
+ self.get = async_to_streamed_response_wrapper(
+ versions.get,
+ )
diff --git a/src/replicate/resources/predictions.py b/src/replicate/resources/predictions.py
index d6a1245..c03991d 100644
--- a/src/replicate/resources/predictions.py
+++ b/src/replicate/resources/predictions.py
@@ -189,108 +189,6 @@ def create(
cast_to=Prediction,
)
- def retrieve(
- self,
- prediction_id: str,
- *,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> Prediction:
- """
- Get the current state of a prediction.
-
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu
- ```
-
- The response will be the prediction object:
-
- ```json
- {
- "id": "gm3qorzdhgbfurvjtvhg6dckhu",
- "model": "replicate/hello-world",
- "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
- "input": {
- "text": "Alice"
- },
- "logs": "",
- "output": "hello Alice",
- "error": null,
- "status": "succeeded",
- "created_at": "2023-09-08T16:19:34.765994Z",
- "data_removed": false,
- "started_at": "2023-09-08T16:19:34.779176Z",
- "completed_at": "2023-09-08T16:19:34.791859Z",
- "metrics": {
- "predict_time": 0.012683
- },
- "urls": {
- "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel",
- "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu"
- }
- }
- ```
-
- `status` will be one of:
-
- - `starting`: the prediction is starting up. If this status lasts longer than a
- few seconds, then it's typically because a new worker is being started to run
- the prediction.
- - `processing`: the `predict()` method of the model is currently running.
- - `succeeded`: the prediction completed successfully.
- - `failed`: the prediction encountered an error during processing.
- - `canceled`: the prediction was canceled by its creator.
-
- In the case of success, `output` will be an object containing the output of the
- model. Any files will be represented as HTTPS URLs. You'll need to pass the
- `Authorization` header to request them.
-
- In the case of failure, `error` will contain the error encountered during the
- prediction.
-
- Terminated predictions (with a status of `succeeded`, `failed`, or `canceled`)
- will include a `metrics` object with a `predict_time` property showing the
- amount of CPU or GPU time, in seconds, that the prediction used while running.
- It won't include time waiting for the prediction to start.
-
- All input parameters, output values, and logs are automatically removed after an
- hour, by default, for predictions created through the API.
-
- You must save a copy of any data or files in the output if you'd like to
- continue using them. The `output` key will still be present, but it's value will
- be `null` after the output has been removed.
-
- Output files are served by `replicate.delivery` and its subdomains. If you use
- an allow list of external domains for your assets, add `replicate.delivery` and
- `*.replicate.delivery` to it.
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not prediction_id:
- raise ValueError(f"Expected a non-empty value for `prediction_id` but received {prediction_id!r}")
- return self._get(
- f"/predictions/{prediction_id}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=Prediction,
- )
-
def list(
self,
*,
@@ -440,6 +338,108 @@ def cancel(
cast_to=NoneType,
)
+ def get(
+ self,
+ prediction_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> Prediction:
+ """
+ Get the current state of a prediction.
+
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu
+ ```
+
+ The response will be the prediction object:
+
+ ```json
+ {
+ "id": "gm3qorzdhgbfurvjtvhg6dckhu",
+ "model": "replicate/hello-world",
+ "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
+ "input": {
+ "text": "Alice"
+ },
+ "logs": "",
+ "output": "hello Alice",
+ "error": null,
+ "status": "succeeded",
+ "created_at": "2023-09-08T16:19:34.765994Z",
+ "data_removed": false,
+ "started_at": "2023-09-08T16:19:34.779176Z",
+ "completed_at": "2023-09-08T16:19:34.791859Z",
+ "metrics": {
+ "predict_time": 0.012683
+ },
+ "urls": {
+ "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel",
+ "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu"
+ }
+ }
+ ```
+
+ `status` will be one of:
+
+ - `starting`: the prediction is starting up. If this status lasts longer than a
+ few seconds, then it's typically because a new worker is being started to run
+ the prediction.
+ - `processing`: the `predict()` method of the model is currently running.
+ - `succeeded`: the prediction completed successfully.
+ - `failed`: the prediction encountered an error during processing.
+ - `canceled`: the prediction was canceled by its creator.
+
+ In the case of success, `output` will be an object containing the output of the
+ model. Any files will be represented as HTTPS URLs. You'll need to pass the
+ `Authorization` header to request them.
+
+ In the case of failure, `error` will contain the error encountered during the
+ prediction.
+
+ Terminated predictions (with a status of `succeeded`, `failed`, or `canceled`)
+ will include a `metrics` object with a `predict_time` property showing the
+ amount of CPU or GPU time, in seconds, that the prediction used while running.
+ It won't include time waiting for the prediction to start.
+
+ All input parameters, output values, and logs are automatically removed after an
+ hour, by default, for predictions created through the API.
+
+ You must save a copy of any data or files in the output if you'd like to
+ continue using them. The `output` key will still be present, but it's value will
+ be `null` after the output has been removed.
+
+ Output files are served by `replicate.delivery` and its subdomains. If you use
+ an allow list of external domains for your assets, add `replicate.delivery` and
+ `*.replicate.delivery` to it.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not prediction_id:
+ raise ValueError(f"Expected a non-empty value for `prediction_id` but received {prediction_id!r}")
+ return self._get(
+ f"/predictions/{prediction_id}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=Prediction,
+ )
+
class AsyncPredictionsResource(AsyncAPIResource):
@cached_property
@@ -600,108 +600,6 @@ async def create(
cast_to=Prediction,
)
- async def retrieve(
- self,
- prediction_id: str,
- *,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> Prediction:
- """
- Get the current state of a prediction.
-
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu
- ```
-
- The response will be the prediction object:
-
- ```json
- {
- "id": "gm3qorzdhgbfurvjtvhg6dckhu",
- "model": "replicate/hello-world",
- "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
- "input": {
- "text": "Alice"
- },
- "logs": "",
- "output": "hello Alice",
- "error": null,
- "status": "succeeded",
- "created_at": "2023-09-08T16:19:34.765994Z",
- "data_removed": false,
- "started_at": "2023-09-08T16:19:34.779176Z",
- "completed_at": "2023-09-08T16:19:34.791859Z",
- "metrics": {
- "predict_time": 0.012683
- },
- "urls": {
- "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel",
- "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu"
- }
- }
- ```
-
- `status` will be one of:
-
- - `starting`: the prediction is starting up. If this status lasts longer than a
- few seconds, then it's typically because a new worker is being started to run
- the prediction.
- - `processing`: the `predict()` method of the model is currently running.
- - `succeeded`: the prediction completed successfully.
- - `failed`: the prediction encountered an error during processing.
- - `canceled`: the prediction was canceled by its creator.
-
- In the case of success, `output` will be an object containing the output of the
- model. Any files will be represented as HTTPS URLs. You'll need to pass the
- `Authorization` header to request them.
-
- In the case of failure, `error` will contain the error encountered during the
- prediction.
-
- Terminated predictions (with a status of `succeeded`, `failed`, or `canceled`)
- will include a `metrics` object with a `predict_time` property showing the
- amount of CPU or GPU time, in seconds, that the prediction used while running.
- It won't include time waiting for the prediction to start.
-
- All input parameters, output values, and logs are automatically removed after an
- hour, by default, for predictions created through the API.
-
- You must save a copy of any data or files in the output if you'd like to
- continue using them. The `output` key will still be present, but it's value will
- be `null` after the output has been removed.
-
- Output files are served by `replicate.delivery` and its subdomains. If you use
- an allow list of external domains for your assets, add `replicate.delivery` and
- `*.replicate.delivery` to it.
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not prediction_id:
- raise ValueError(f"Expected a non-empty value for `prediction_id` but received {prediction_id!r}")
- return await self._get(
- f"/predictions/{prediction_id}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=Prediction,
- )
-
def list(
self,
*,
@@ -851,6 +749,108 @@ async def cancel(
cast_to=NoneType,
)
+ async def get(
+ self,
+ prediction_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> Prediction:
+ """
+ Get the current state of a prediction.
+
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu
+ ```
+
+ The response will be the prediction object:
+
+ ```json
+ {
+ "id": "gm3qorzdhgbfurvjtvhg6dckhu",
+ "model": "replicate/hello-world",
+ "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
+ "input": {
+ "text": "Alice"
+ },
+ "logs": "",
+ "output": "hello Alice",
+ "error": null,
+ "status": "succeeded",
+ "created_at": "2023-09-08T16:19:34.765994Z",
+ "data_removed": false,
+ "started_at": "2023-09-08T16:19:34.779176Z",
+ "completed_at": "2023-09-08T16:19:34.791859Z",
+ "metrics": {
+ "predict_time": 0.012683
+ },
+ "urls": {
+ "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel",
+ "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu"
+ }
+ }
+ ```
+
+ `status` will be one of:
+
+ - `starting`: the prediction is starting up. If this status lasts longer than a
+ few seconds, then it's typically because a new worker is being started to run
+ the prediction.
+ - `processing`: the `predict()` method of the model is currently running.
+ - `succeeded`: the prediction completed successfully.
+ - `failed`: the prediction encountered an error during processing.
+ - `canceled`: the prediction was canceled by its creator.
+
+ In the case of success, `output` will be an object containing the output of the
+ model. Any files will be represented as HTTPS URLs. You'll need to pass the
+ `Authorization` header to request them.
+
+ In the case of failure, `error` will contain the error encountered during the
+ prediction.
+
+ Terminated predictions (with a status of `succeeded`, `failed`, or `canceled`)
+ will include a `metrics` object with a `predict_time` property showing the
+ amount of CPU or GPU time, in seconds, that the prediction used while running.
+ It won't include time waiting for the prediction to start.
+
+ All input parameters, output values, and logs are automatically removed after an
+ hour, by default, for predictions created through the API.
+
+ You must save a copy of any data or files in the output if you'd like to
+ continue using them. The `output` key will still be present, but it's value will
+ be `null` after the output has been removed.
+
+ Output files are served by `replicate.delivery` and its subdomains. If you use
+ an allow list of external domains for your assets, add `replicate.delivery` and
+ `*.replicate.delivery` to it.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not prediction_id:
+ raise ValueError(f"Expected a non-empty value for `prediction_id` but received {prediction_id!r}")
+ return await self._get(
+ f"/predictions/{prediction_id}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=Prediction,
+ )
+
class PredictionsResourceWithRawResponse:
def __init__(self, predictions: PredictionsResource) -> None:
@@ -859,15 +859,15 @@ def __init__(self, predictions: PredictionsResource) -> None:
self.create = to_raw_response_wrapper(
predictions.create,
)
- self.retrieve = to_raw_response_wrapper(
- predictions.retrieve,
- )
self.list = to_raw_response_wrapper(
predictions.list,
)
self.cancel = to_raw_response_wrapper(
predictions.cancel,
)
+ self.get = to_raw_response_wrapper(
+ predictions.get,
+ )
class AsyncPredictionsResourceWithRawResponse:
@@ -877,15 +877,15 @@ def __init__(self, predictions: AsyncPredictionsResource) -> None:
self.create = async_to_raw_response_wrapper(
predictions.create,
)
- self.retrieve = async_to_raw_response_wrapper(
- predictions.retrieve,
- )
self.list = async_to_raw_response_wrapper(
predictions.list,
)
self.cancel = async_to_raw_response_wrapper(
predictions.cancel,
)
+ self.get = async_to_raw_response_wrapper(
+ predictions.get,
+ )
class PredictionsResourceWithStreamingResponse:
@@ -895,15 +895,15 @@ def __init__(self, predictions: PredictionsResource) -> None:
self.create = to_streamed_response_wrapper(
predictions.create,
)
- self.retrieve = to_streamed_response_wrapper(
- predictions.retrieve,
- )
self.list = to_streamed_response_wrapper(
predictions.list,
)
self.cancel = to_streamed_response_wrapper(
predictions.cancel,
)
+ self.get = to_streamed_response_wrapper(
+ predictions.get,
+ )
class AsyncPredictionsResourceWithStreamingResponse:
@@ -913,12 +913,12 @@ def __init__(self, predictions: AsyncPredictionsResource) -> None:
self.create = async_to_streamed_response_wrapper(
predictions.create,
)
- self.retrieve = async_to_streamed_response_wrapper(
- predictions.retrieve,
- )
self.list = async_to_streamed_response_wrapper(
predictions.list,
)
self.cancel = async_to_streamed_response_wrapper(
predictions.cancel,
)
+ self.get = async_to_streamed_response_wrapper(
+ predictions.get,
+ )
diff --git a/src/replicate/resources/trainings.py b/src/replicate/resources/trainings.py
index 46c5872..5a357af 100644
--- a/src/replicate/resources/trainings.py
+++ b/src/replicate/resources/trainings.py
@@ -15,9 +15,9 @@
)
from ..pagination import SyncCursorURLPage, AsyncCursorURLPage
from .._base_client import AsyncPaginator, make_request_options
+from ..types.training_get_response import TrainingGetResponse
from ..types.training_list_response import TrainingListResponse
from ..types.training_cancel_response import TrainingCancelResponse
-from ..types.training_retrieve_response import TrainingRetrieveResponse
__all__ = ["TrainingsResource", "AsyncTrainingsResource"]
@@ -42,99 +42,6 @@ def with_streaming_response(self) -> TrainingsResourceWithStreamingResponse:
"""
return TrainingsResourceWithStreamingResponse(self)
- def retrieve(
- self,
- training_id: str,
- *,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> TrainingRetrieveResponse:
- """
- Get the current state of a training.
-
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga
- ```
-
- The response will be the training object:
-
- ```json
- {
- "completed_at": "2023-09-08T16:41:19.826523Z",
- "created_at": "2023-09-08T16:32:57.018467Z",
- "error": null,
- "id": "zz4ibbonubfz7carwiefibzgga",
- "input": {
- "input_images": "https://example.com/my-input-images.zip"
- },
- "logs": "...",
- "metrics": {
- "predict_time": 502.713876
- },
- "output": {
- "version": "...",
- "weights": "..."
- },
- "started_at": "2023-09-08T16:32:57.112647Z",
- "status": "succeeded",
- "urls": {
- "get": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga",
- "cancel": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel"
- },
- "model": "stability-ai/sdxl",
- "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf"
- }
- ```
-
- `status` will be one of:
-
- - `starting`: the training is starting up. If this status lasts longer than a
- few seconds, then it's typically because a new worker is being started to run
- the training.
- - `processing`: the `train()` method of the model is currently running.
- - `succeeded`: the training completed successfully.
- - `failed`: the training encountered an error during processing.
- - `canceled`: the training was canceled by its creator.
-
- In the case of success, `output` will be an object containing the output of the
- model. Any files will be represented as HTTPS URLs. You'll need to pass the
- `Authorization` header to request them.
-
- In the case of failure, `error` will contain the error encountered during the
- training.
-
- Terminated trainings (with a status of `succeeded`, `failed`, or `canceled`)
- will include a `metrics` object with a `predict_time` property showing the
- amount of CPU or GPU time, in seconds, that the training used while running. It
- won't include time waiting for the training to start.
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not training_id:
- raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}")
- return self._get(
- f"/trainings/{training_id}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=TrainingRetrieveResponse,
- )
-
def list(
self,
*,
@@ -252,28 +159,7 @@ def cancel(
cast_to=TrainingCancelResponse,
)
-
-class AsyncTrainingsResource(AsyncAPIResource):
- @cached_property
- def with_raw_response(self) -> AsyncTrainingsResourceWithRawResponse:
- """
- This property can be used as a prefix for any HTTP method call to return
- the raw response object instead of the parsed content.
-
- For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers
- """
- return AsyncTrainingsResourceWithRawResponse(self)
-
- @cached_property
- def with_streaming_response(self) -> AsyncTrainingsResourceWithStreamingResponse:
- """
- An alternative to `.with_raw_response` that doesn't eagerly read the response body.
-
- For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response
- """
- return AsyncTrainingsResourceWithStreamingResponse(self)
-
- async def retrieve(
+ def get(
self,
training_id: str,
*,
@@ -283,7 +169,7 @@ async def retrieve(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> TrainingRetrieveResponse:
+ ) -> TrainingGetResponse:
"""
Get the current state of a training.
@@ -358,14 +244,35 @@ async def retrieve(
"""
if not training_id:
raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}")
- return await self._get(
+ return self._get(
f"/trainings/{training_id}",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=TrainingRetrieveResponse,
+ cast_to=TrainingGetResponse,
)
+
+class AsyncTrainingsResource(AsyncAPIResource):
+ @cached_property
+ def with_raw_response(self) -> AsyncTrainingsResourceWithRawResponse:
+ """
+ This property can be used as a prefix for any HTTP method call to return
+ the raw response object instead of the parsed content.
+
+ For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers
+ """
+ return AsyncTrainingsResourceWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> AsyncTrainingsResourceWithStreamingResponse:
+ """
+ An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+ For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response
+ """
+ return AsyncTrainingsResourceWithStreamingResponse(self)
+
def list(
self,
*,
@@ -483,62 +390,155 @@ async def cancel(
cast_to=TrainingCancelResponse,
)
+ async def get(
+ self,
+ training_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> TrainingGetResponse:
+ """
+ Get the current state of a training.
+
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga
+ ```
+
+ The response will be the training object:
+
+ ```json
+ {
+ "completed_at": "2023-09-08T16:41:19.826523Z",
+ "created_at": "2023-09-08T16:32:57.018467Z",
+ "error": null,
+ "id": "zz4ibbonubfz7carwiefibzgga",
+ "input": {
+ "input_images": "https://example.com/my-input-images.zip"
+ },
+ "logs": "...",
+ "metrics": {
+ "predict_time": 502.713876
+ },
+ "output": {
+ "version": "...",
+ "weights": "..."
+ },
+ "started_at": "2023-09-08T16:32:57.112647Z",
+ "status": "succeeded",
+ "urls": {
+ "get": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga",
+ "cancel": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel"
+ },
+ "model": "stability-ai/sdxl",
+ "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf"
+ }
+ ```
+
+ `status` will be one of:
+
+ - `starting`: the training is starting up. If this status lasts longer than a
+ few seconds, then it's typically because a new worker is being started to run
+ the training.
+ - `processing`: the `train()` method of the model is currently running.
+ - `succeeded`: the training completed successfully.
+ - `failed`: the training encountered an error during processing.
+ - `canceled`: the training was canceled by its creator.
+
+ In the case of success, `output` will be an object containing the output of the
+ model. Any files will be represented as HTTPS URLs. You'll need to pass the
+ `Authorization` header to request them.
+
+ In the case of failure, `error` will contain the error encountered during the
+ training.
+
+ Terminated trainings (with a status of `succeeded`, `failed`, or `canceled`)
+ will include a `metrics` object with a `predict_time` property showing the
+ amount of CPU or GPU time, in seconds, that the training used while running. It
+ won't include time waiting for the training to start.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not training_id:
+ raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}")
+ return await self._get(
+ f"/trainings/{training_id}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=TrainingGetResponse,
+ )
+
class TrainingsResourceWithRawResponse:
def __init__(self, trainings: TrainingsResource) -> None:
self._trainings = trainings
- self.retrieve = to_raw_response_wrapper(
- trainings.retrieve,
- )
self.list = to_raw_response_wrapper(
trainings.list,
)
self.cancel = to_raw_response_wrapper(
trainings.cancel,
)
+ self.get = to_raw_response_wrapper(
+ trainings.get,
+ )
class AsyncTrainingsResourceWithRawResponse:
def __init__(self, trainings: AsyncTrainingsResource) -> None:
self._trainings = trainings
- self.retrieve = async_to_raw_response_wrapper(
- trainings.retrieve,
- )
self.list = async_to_raw_response_wrapper(
trainings.list,
)
self.cancel = async_to_raw_response_wrapper(
trainings.cancel,
)
+ self.get = async_to_raw_response_wrapper(
+ trainings.get,
+ )
class TrainingsResourceWithStreamingResponse:
def __init__(self, trainings: TrainingsResource) -> None:
self._trainings = trainings
- self.retrieve = to_streamed_response_wrapper(
- trainings.retrieve,
- )
self.list = to_streamed_response_wrapper(
trainings.list,
)
self.cancel = to_streamed_response_wrapper(
trainings.cancel,
)
+ self.get = to_streamed_response_wrapper(
+ trainings.get,
+ )
class AsyncTrainingsResourceWithStreamingResponse:
def __init__(self, trainings: AsyncTrainingsResource) -> None:
self._trainings = trainings
- self.retrieve = async_to_streamed_response_wrapper(
- trainings.retrieve,
- )
self.list = async_to_streamed_response_wrapper(
trainings.list,
)
self.cancel = async_to_streamed_response_wrapper(
trainings.cancel,
)
+ self.get = async_to_streamed_response_wrapper(
+ trainings.get,
+ )
diff --git a/src/replicate/types/__init__.py b/src/replicate/types/__init__.py
index 7d2d8f4..fa8ee5d 100644
--- a/src/replicate/types/__init__.py
+++ b/src/replicate/types/__init__.py
@@ -7,9 +7,11 @@
from .model_create_params import ModelCreateParams as ModelCreateParams
from .model_list_response import ModelListResponse as ModelListResponse
from .account_list_response import AccountListResponse as AccountListResponse
+from .training_get_response import TrainingGetResponse as TrainingGetResponse
from .hardware_list_response import HardwareListResponse as HardwareListResponse
from .prediction_list_params import PredictionListParams as PredictionListParams
from .training_list_response import TrainingListResponse as TrainingListResponse
+from .deployment_get_response import DeploymentGetResponse as DeploymentGetResponse
from .deployment_create_params import DeploymentCreateParams as DeploymentCreateParams
from .deployment_list_response import DeploymentListResponse as DeploymentListResponse
from .deployment_update_params import DeploymentUpdateParams as DeploymentUpdateParams
@@ -17,6 +19,4 @@
from .training_cancel_response import TrainingCancelResponse as TrainingCancelResponse
from .deployment_create_response import DeploymentCreateResponse as DeploymentCreateResponse
from .deployment_update_response import DeploymentUpdateResponse as DeploymentUpdateResponse
-from .training_retrieve_response import TrainingRetrieveResponse as TrainingRetrieveResponse
-from .deployment_retrieve_response import DeploymentRetrieveResponse as DeploymentRetrieveResponse
from .model_create_prediction_params import ModelCreatePredictionParams as ModelCreatePredictionParams
diff --git a/src/replicate/types/deployment_retrieve_response.py b/src/replicate/types/deployment_get_response.py
similarity index 91%
rename from src/replicate/types/deployment_retrieve_response.py
rename to src/replicate/types/deployment_get_response.py
index 2ecc13c..a7281ff 100644
--- a/src/replicate/types/deployment_retrieve_response.py
+++ b/src/replicate/types/deployment_get_response.py
@@ -6,7 +6,7 @@
from .._models import BaseModel
-__all__ = ["DeploymentRetrieveResponse", "CurrentRelease", "CurrentReleaseConfiguration", "CurrentReleaseCreatedBy"]
+__all__ = ["DeploymentGetResponse", "CurrentRelease", "CurrentReleaseConfiguration", "CurrentReleaseCreatedBy"]
class CurrentReleaseConfiguration(BaseModel):
@@ -55,7 +55,7 @@ class CurrentRelease(BaseModel):
"""The ID of the model version used in the release."""
-class DeploymentRetrieveResponse(BaseModel):
+class DeploymentGetResponse(BaseModel):
current_release: Optional[CurrentRelease] = None
name: Optional[str] = None
diff --git a/src/replicate/types/training_retrieve_response.py b/src/replicate/types/training_get_response.py
similarity index 94%
rename from src/replicate/types/training_retrieve_response.py
rename to src/replicate/types/training_get_response.py
index 55b1173..4169da7 100644
--- a/src/replicate/types/training_retrieve_response.py
+++ b/src/replicate/types/training_get_response.py
@@ -6,7 +6,7 @@
from .._models import BaseModel
-__all__ = ["TrainingRetrieveResponse", "Metrics", "Output", "URLs"]
+__all__ = ["TrainingGetResponse", "Metrics", "Output", "URLs"]
class Metrics(BaseModel):
@@ -30,7 +30,7 @@ class URLs(BaseModel):
"""URL to get the training details"""
-class TrainingRetrieveResponse(BaseModel):
+class TrainingGetResponse(BaseModel):
id: Optional[str] = None
"""The unique ID of the training"""
diff --git a/tests/api_resources/models/test_versions.py b/tests/api_resources/models/test_versions.py
index 7e6d583..d1fb7a8 100644
--- a/tests/api_resources/models/test_versions.py
+++ b/tests/api_resources/models/test_versions.py
@@ -17,70 +17,6 @@
class TestVersions:
parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
- @pytest.mark.skip()
- @parametrize
- def test_method_retrieve(self, client: ReplicateClient) -> None:
- version = client.models.versions.retrieve(
- version_id="version_id",
- model_owner="model_owner",
- model_name="model_name",
- )
- assert version is None
-
- @pytest.mark.skip()
- @parametrize
- def test_raw_response_retrieve(self, client: ReplicateClient) -> None:
- response = client.models.versions.with_raw_response.retrieve(
- version_id="version_id",
- model_owner="model_owner",
- model_name="model_name",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- version = response.parse()
- assert version is None
-
- @pytest.mark.skip()
- @parametrize
- def test_streaming_response_retrieve(self, client: ReplicateClient) -> None:
- with client.models.versions.with_streaming_response.retrieve(
- version_id="version_id",
- model_owner="model_owner",
- model_name="model_name",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- version = response.parse()
- assert version is None
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- def test_path_params_retrieve(self, client: ReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
- client.models.versions.with_raw_response.retrieve(
- version_id="version_id",
- model_owner="",
- model_name="model_name",
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
- client.models.versions.with_raw_response.retrieve(
- version_id="version_id",
- model_owner="model_owner",
- model_name="",
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `version_id` but received ''"):
- client.models.versions.with_raw_response.retrieve(
- version_id="",
- model_owner="model_owner",
- model_name="model_name",
- )
-
@pytest.mark.skip()
@parametrize
def test_method_list(self, client: ReplicateClient) -> None:
@@ -287,14 +223,10 @@ def test_path_params_create_training(self, client: ReplicateClient) -> None:
input={},
)
-
-class TestAsyncVersions:
- parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
-
@pytest.mark.skip()
@parametrize
- async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None:
- version = await async_client.models.versions.retrieve(
+ def test_method_get(self, client: ReplicateClient) -> None:
+ version = client.models.versions.get(
version_id="version_id",
model_owner="model_owner",
model_name="model_name",
@@ -303,8 +235,8 @@ async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None
@pytest.mark.skip()
@parametrize
- async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- response = await async_client.models.versions.with_raw_response.retrieve(
+ def test_raw_response_get(self, client: ReplicateClient) -> None:
+ response = client.models.versions.with_raw_response.get(
version_id="version_id",
model_owner="model_owner",
model_name="model_name",
@@ -312,13 +244,13 @@ async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- version = await response.parse()
+ version = response.parse()
assert version is None
@pytest.mark.skip()
@parametrize
- async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- async with async_client.models.versions.with_streaming_response.retrieve(
+ def test_streaming_response_get(self, client: ReplicateClient) -> None:
+ with client.models.versions.with_streaming_response.get(
version_id="version_id",
model_owner="model_owner",
model_name="model_name",
@@ -326,35 +258,39 @@ async def test_streaming_response_retrieve(self, async_client: AsyncReplicateCli
assert not response.is_closed
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- version = await response.parse()
+ version = response.parse()
assert version is None
assert cast(Any, response.is_closed) is True
@pytest.mark.skip()
@parametrize
- async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None:
+ def test_path_params_get(self, client: ReplicateClient) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
- await async_client.models.versions.with_raw_response.retrieve(
+ client.models.versions.with_raw_response.get(
version_id="version_id",
model_owner="",
model_name="model_name",
)
with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
- await async_client.models.versions.with_raw_response.retrieve(
+ client.models.versions.with_raw_response.get(
version_id="version_id",
model_owner="model_owner",
model_name="",
)
with pytest.raises(ValueError, match=r"Expected a non-empty value for `version_id` but received ''"):
- await async_client.models.versions.with_raw_response.retrieve(
+ client.models.versions.with_raw_response.get(
version_id="",
model_owner="model_owner",
model_name="model_name",
)
+
+class TestAsyncVersions:
+ parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+
@pytest.mark.skip()
@parametrize
async def test_method_list(self, async_client: AsyncReplicateClient) -> None:
@@ -560,3 +496,67 @@ async def test_path_params_create_training(self, async_client: AsyncReplicateCli
destination="destination",
input={},
)
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_method_get(self, async_client: AsyncReplicateClient) -> None:
+ version = await async_client.models.versions.get(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="model_name",
+ )
+ assert version is None
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None:
+ response = await async_client.models.versions.with_raw_response.get(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="model_name",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ version = await response.parse()
+ assert version is None
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None:
+ async with async_client.models.versions.with_streaming_response.get(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="model_name",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ version = await response.parse()
+ assert version is None
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
+ await async_client.models.versions.with_raw_response.get(
+ version_id="version_id",
+ model_owner="",
+ model_name="model_name",
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
+ await async_client.models.versions.with_raw_response.get(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="",
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `version_id` but received ''"):
+ await async_client.models.versions.with_raw_response.get(
+ version_id="",
+ model_owner="model_owner",
+ model_name="model_name",
+ )
diff --git a/tests/api_resources/test_deployments.py b/tests/api_resources/test_deployments.py
index b834241..3a01dcb 100644
--- a/tests/api_resources/test_deployments.py
+++ b/tests/api_resources/test_deployments.py
@@ -10,10 +10,10 @@
from replicate import ReplicateClient, AsyncReplicateClient
from tests.utils import assert_matches_type
from replicate.types import (
+ DeploymentGetResponse,
DeploymentListResponse,
DeploymentCreateResponse,
DeploymentUpdateResponse,
- DeploymentRetrieveResponse,
)
from replicate.pagination import SyncCursorURLPage, AsyncCursorURLPage
@@ -72,58 +72,6 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
- @parametrize
- def test_method_retrieve(self, client: ReplicateClient) -> None:
- deployment = client.deployments.retrieve(
- deployment_name="deployment_name",
- deployment_owner="deployment_owner",
- )
- assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_raw_response_retrieve(self, client: ReplicateClient) -> None:
- response = client.deployments.with_raw_response.retrieve(
- deployment_name="deployment_name",
- deployment_owner="deployment_owner",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- deployment = response.parse()
- assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_streaming_response_retrieve(self, client: ReplicateClient) -> None:
- with client.deployments.with_streaming_response.retrieve(
- deployment_name="deployment_name",
- deployment_owner="deployment_owner",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- deployment = response.parse()
- assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- def test_path_params_retrieve(self, client: ReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"):
- client.deployments.with_raw_response.retrieve(
- deployment_name="deployment_name",
- deployment_owner="",
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_name` but received ''"):
- client.deployments.with_raw_response.retrieve(
- deployment_name="",
- deployment_owner="deployment_owner",
- )
-
@pytest.mark.skip()
@parametrize
def test_method_update(self, client: ReplicateClient) -> None:
@@ -269,6 +217,58 @@ def test_path_params_delete(self, client: ReplicateClient) -> None:
deployment_owner="deployment_owner",
)
+ @pytest.mark.skip()
+ @parametrize
+ def test_method_get(self, client: ReplicateClient) -> None:
+ deployment = client.deployments.get(
+ deployment_name="deployment_name",
+ deployment_owner="deployment_owner",
+ )
+ assert_matches_type(DeploymentGetResponse, deployment, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_raw_response_get(self, client: ReplicateClient) -> None:
+ response = client.deployments.with_raw_response.get(
+ deployment_name="deployment_name",
+ deployment_owner="deployment_owner",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ deployment = response.parse()
+ assert_matches_type(DeploymentGetResponse, deployment, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_streaming_response_get(self, client: ReplicateClient) -> None:
+ with client.deployments.with_streaming_response.get(
+ deployment_name="deployment_name",
+ deployment_owner="deployment_owner",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ deployment = response.parse()
+ assert_matches_type(DeploymentGetResponse, deployment, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_path_params_get(self, client: ReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"):
+ client.deployments.with_raw_response.get(
+ deployment_name="deployment_name",
+ deployment_owner="",
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_name` but received ''"):
+ client.deployments.with_raw_response.get(
+ deployment_name="",
+ deployment_owner="deployment_owner",
+ )
+
@pytest.mark.skip()
@parametrize
def test_method_list_em_all(self, client: ReplicateClient) -> None:
@@ -350,58 +350,6 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
- @parametrize
- async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None:
- deployment = await async_client.deployments.retrieve(
- deployment_name="deployment_name",
- deployment_owner="deployment_owner",
- )
- assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- response = await async_client.deployments.with_raw_response.retrieve(
- deployment_name="deployment_name",
- deployment_owner="deployment_owner",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- deployment = await response.parse()
- assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- async with async_client.deployments.with_streaming_response.retrieve(
- deployment_name="deployment_name",
- deployment_owner="deployment_owner",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- deployment = await response.parse()
- assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"):
- await async_client.deployments.with_raw_response.retrieve(
- deployment_name="deployment_name",
- deployment_owner="",
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_name` but received ''"):
- await async_client.deployments.with_raw_response.retrieve(
- deployment_name="",
- deployment_owner="deployment_owner",
- )
-
@pytest.mark.skip()
@parametrize
async def test_method_update(self, async_client: AsyncReplicateClient) -> None:
@@ -547,6 +495,58 @@ async def test_path_params_delete(self, async_client: AsyncReplicateClient) -> N
deployment_owner="deployment_owner",
)
+ @pytest.mark.skip()
+ @parametrize
+ async def test_method_get(self, async_client: AsyncReplicateClient) -> None:
+ deployment = await async_client.deployments.get(
+ deployment_name="deployment_name",
+ deployment_owner="deployment_owner",
+ )
+ assert_matches_type(DeploymentGetResponse, deployment, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None:
+ response = await async_client.deployments.with_raw_response.get(
+ deployment_name="deployment_name",
+ deployment_owner="deployment_owner",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ deployment = await response.parse()
+ assert_matches_type(DeploymentGetResponse, deployment, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None:
+ async with async_client.deployments.with_streaming_response.get(
+ deployment_name="deployment_name",
+ deployment_owner="deployment_owner",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ deployment = await response.parse()
+ assert_matches_type(DeploymentGetResponse, deployment, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"):
+ await async_client.deployments.with_raw_response.get(
+ deployment_name="deployment_name",
+ deployment_owner="",
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_name` but received ''"):
+ await async_client.deployments.with_raw_response.get(
+ deployment_name="",
+ deployment_owner="deployment_owner",
+ )
+
@pytest.mark.skip()
@parametrize
async def test_method_list_em_all(self, async_client: AsyncReplicateClient) -> None:
diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py
index f56f00c..a2a9f52 100644
--- a/tests/api_resources/test_models.py
+++ b/tests/api_resources/test_models.py
@@ -77,58 +77,6 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
- @parametrize
- def test_method_retrieve(self, client: ReplicateClient) -> None:
- model = client.models.retrieve(
- model_name="model_name",
- model_owner="model_owner",
- )
- assert model is None
-
- @pytest.mark.skip()
- @parametrize
- def test_raw_response_retrieve(self, client: ReplicateClient) -> None:
- response = client.models.with_raw_response.retrieve(
- model_name="model_name",
- model_owner="model_owner",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- model = response.parse()
- assert model is None
-
- @pytest.mark.skip()
- @parametrize
- def test_streaming_response_retrieve(self, client: ReplicateClient) -> None:
- with client.models.with_streaming_response.retrieve(
- model_name="model_name",
- model_owner="model_owner",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- model = response.parse()
- assert model is None
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- def test_path_params_retrieve(self, client: ReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
- client.models.with_raw_response.retrieve(
- model_name="model_name",
- model_owner="",
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
- client.models.with_raw_response.retrieve(
- model_name="",
- model_owner="model_owner",
- )
-
@pytest.mark.skip()
@parametrize
def test_method_list(self, client: ReplicateClient) -> None:
@@ -280,6 +228,58 @@ def test_path_params_create_prediction(self, client: ReplicateClient) -> None:
input={},
)
+ @pytest.mark.skip()
+ @parametrize
+ def test_method_get(self, client: ReplicateClient) -> None:
+ model = client.models.get(
+ model_name="model_name",
+ model_owner="model_owner",
+ )
+ assert model is None
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_raw_response_get(self, client: ReplicateClient) -> None:
+ response = client.models.with_raw_response.get(
+ model_name="model_name",
+ model_owner="model_owner",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ model = response.parse()
+ assert model is None
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_streaming_response_get(self, client: ReplicateClient) -> None:
+ with client.models.with_streaming_response.get(
+ model_name="model_name",
+ model_owner="model_owner",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ model = response.parse()
+ assert model is None
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_path_params_get(self, client: ReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
+ client.models.with_raw_response.get(
+ model_name="model_name",
+ model_owner="",
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
+ client.models.with_raw_response.get(
+ model_name="",
+ model_owner="model_owner",
+ )
+
class TestAsyncModels:
parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
@@ -343,58 +343,6 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
- @parametrize
- async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None:
- model = await async_client.models.retrieve(
- model_name="model_name",
- model_owner="model_owner",
- )
- assert model is None
-
- @pytest.mark.skip()
- @parametrize
- async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- response = await async_client.models.with_raw_response.retrieve(
- model_name="model_name",
- model_owner="model_owner",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- model = await response.parse()
- assert model is None
-
- @pytest.mark.skip()
- @parametrize
- async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- async with async_client.models.with_streaming_response.retrieve(
- model_name="model_name",
- model_owner="model_owner",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- model = await response.parse()
- assert model is None
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
- await async_client.models.with_raw_response.retrieve(
- model_name="model_name",
- model_owner="",
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
- await async_client.models.with_raw_response.retrieve(
- model_name="",
- model_owner="model_owner",
- )
-
@pytest.mark.skip()
@parametrize
async def test_method_list(self, async_client: AsyncReplicateClient) -> None:
@@ -545,3 +493,55 @@ async def test_path_params_create_prediction(self, async_client: AsyncReplicateC
model_owner="model_owner",
input={},
)
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_method_get(self, async_client: AsyncReplicateClient) -> None:
+ model = await async_client.models.get(
+ model_name="model_name",
+ model_owner="model_owner",
+ )
+ assert model is None
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None:
+ response = await async_client.models.with_raw_response.get(
+ model_name="model_name",
+ model_owner="model_owner",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ model = await response.parse()
+ assert model is None
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None:
+ async with async_client.models.with_streaming_response.get(
+ model_name="model_name",
+ model_owner="model_owner",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ model = await response.parse()
+ assert model is None
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
+ await async_client.models.with_raw_response.get(
+ model_name="model_name",
+ model_owner="",
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
+ await async_client.models.with_raw_response.get(
+ model_name="",
+ model_owner="model_owner",
+ )
diff --git a/tests/api_resources/test_predictions.py b/tests/api_resources/test_predictions.py
index 8fc4142..51bd80d 100644
--- a/tests/api_resources/test_predictions.py
+++ b/tests/api_resources/test_predictions.py
@@ -69,48 +69,6 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
- @parametrize
- def test_method_retrieve(self, client: ReplicateClient) -> None:
- prediction = client.predictions.retrieve(
- "prediction_id",
- )
- assert_matches_type(Prediction, prediction, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_raw_response_retrieve(self, client: ReplicateClient) -> None:
- response = client.predictions.with_raw_response.retrieve(
- "prediction_id",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- prediction = response.parse()
- assert_matches_type(Prediction, prediction, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_streaming_response_retrieve(self, client: ReplicateClient) -> None:
- with client.predictions.with_streaming_response.retrieve(
- "prediction_id",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- prediction = response.parse()
- assert_matches_type(Prediction, prediction, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- def test_path_params_retrieve(self, client: ReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"):
- client.predictions.with_raw_response.retrieve(
- "",
- )
-
@pytest.mark.skip()
@parametrize
def test_method_list(self, client: ReplicateClient) -> None:
@@ -190,6 +148,48 @@ def test_path_params_cancel(self, client: ReplicateClient) -> None:
"",
)
+ @pytest.mark.skip()
+ @parametrize
+ def test_method_get(self, client: ReplicateClient) -> None:
+ prediction = client.predictions.get(
+ "prediction_id",
+ )
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_raw_response_get(self, client: ReplicateClient) -> None:
+ response = client.predictions.with_raw_response.get(
+ "prediction_id",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ prediction = response.parse()
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_streaming_response_get(self, client: ReplicateClient) -> None:
+ with client.predictions.with_streaming_response.get(
+ "prediction_id",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ prediction = response.parse()
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_path_params_get(self, client: ReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"):
+ client.predictions.with_raw_response.get(
+ "",
+ )
+
class TestAsyncPredictions:
parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
@@ -244,48 +244,6 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
- @parametrize
- async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None:
- prediction = await async_client.predictions.retrieve(
- "prediction_id",
- )
- assert_matches_type(Prediction, prediction, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- response = await async_client.predictions.with_raw_response.retrieve(
- "prediction_id",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- prediction = await response.parse()
- assert_matches_type(Prediction, prediction, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- async with async_client.predictions.with_streaming_response.retrieve(
- "prediction_id",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- prediction = await response.parse()
- assert_matches_type(Prediction, prediction, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"):
- await async_client.predictions.with_raw_response.retrieve(
- "",
- )
-
@pytest.mark.skip()
@parametrize
async def test_method_list(self, async_client: AsyncReplicateClient) -> None:
@@ -364,3 +322,45 @@ async def test_path_params_cancel(self, async_client: AsyncReplicateClient) -> N
await async_client.predictions.with_raw_response.cancel(
"",
)
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_method_get(self, async_client: AsyncReplicateClient) -> None:
+ prediction = await async_client.predictions.get(
+ "prediction_id",
+ )
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None:
+ response = await async_client.predictions.with_raw_response.get(
+ "prediction_id",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ prediction = await response.parse()
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None:
+ async with async_client.predictions.with_streaming_response.get(
+ "prediction_id",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ prediction = await response.parse()
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"):
+ await async_client.predictions.with_raw_response.get(
+ "",
+ )
diff --git a/tests/api_resources/test_trainings.py b/tests/api_resources/test_trainings.py
index 68b1a85..f2dadb1 100644
--- a/tests/api_resources/test_trainings.py
+++ b/tests/api_resources/test_trainings.py
@@ -9,7 +9,7 @@
from replicate import ReplicateClient, AsyncReplicateClient
from tests.utils import assert_matches_type
-from replicate.types import TrainingListResponse, TrainingCancelResponse, TrainingRetrieveResponse
+from replicate.types import TrainingGetResponse, TrainingListResponse, TrainingCancelResponse
from replicate.pagination import SyncCursorURLPage, AsyncCursorURLPage
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
@@ -18,48 +18,6 @@
class TestTrainings:
parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
- @pytest.mark.skip()
- @parametrize
- def test_method_retrieve(self, client: ReplicateClient) -> None:
- training = client.trainings.retrieve(
- "training_id",
- )
- assert_matches_type(TrainingRetrieveResponse, training, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_raw_response_retrieve(self, client: ReplicateClient) -> None:
- response = client.trainings.with_raw_response.retrieve(
- "training_id",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- training = response.parse()
- assert_matches_type(TrainingRetrieveResponse, training, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_streaming_response_retrieve(self, client: ReplicateClient) -> None:
- with client.trainings.with_streaming_response.retrieve(
- "training_id",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- training = response.parse()
- assert_matches_type(TrainingRetrieveResponse, training, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- def test_path_params_retrieve(self, client: ReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `training_id` but received ''"):
- client.trainings.with_raw_response.retrieve(
- "",
- )
-
@pytest.mark.skip()
@parametrize
def test_method_list(self, client: ReplicateClient) -> None:
@@ -130,52 +88,52 @@ def test_path_params_cancel(self, client: ReplicateClient) -> None:
"",
)
-
-class TestAsyncTrainings:
- parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
-
@pytest.mark.skip()
@parametrize
- async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None:
- training = await async_client.trainings.retrieve(
+ def test_method_get(self, client: ReplicateClient) -> None:
+ training = client.trainings.get(
"training_id",
)
- assert_matches_type(TrainingRetrieveResponse, training, path=["response"])
+ assert_matches_type(TrainingGetResponse, training, path=["response"])
@pytest.mark.skip()
@parametrize
- async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- response = await async_client.trainings.with_raw_response.retrieve(
+ def test_raw_response_get(self, client: ReplicateClient) -> None:
+ response = client.trainings.with_raw_response.get(
"training_id",
)
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- training = await response.parse()
- assert_matches_type(TrainingRetrieveResponse, training, path=["response"])
+ training = response.parse()
+ assert_matches_type(TrainingGetResponse, training, path=["response"])
@pytest.mark.skip()
@parametrize
- async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- async with async_client.trainings.with_streaming_response.retrieve(
+ def test_streaming_response_get(self, client: ReplicateClient) -> None:
+ with client.trainings.with_streaming_response.get(
"training_id",
) as response:
assert not response.is_closed
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- training = await response.parse()
- assert_matches_type(TrainingRetrieveResponse, training, path=["response"])
+ training = response.parse()
+ assert_matches_type(TrainingGetResponse, training, path=["response"])
assert cast(Any, response.is_closed) is True
@pytest.mark.skip()
@parametrize
- async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None:
+ def test_path_params_get(self, client: ReplicateClient) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `training_id` but received ''"):
- await async_client.trainings.with_raw_response.retrieve(
+ client.trainings.with_raw_response.get(
"",
)
+
+class TestAsyncTrainings:
+ parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+
@pytest.mark.skip()
@parametrize
async def test_method_list(self, async_client: AsyncReplicateClient) -> None:
@@ -245,3 +203,45 @@ async def test_path_params_cancel(self, async_client: AsyncReplicateClient) -> N
await async_client.trainings.with_raw_response.cancel(
"",
)
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_method_get(self, async_client: AsyncReplicateClient) -> None:
+ training = await async_client.trainings.get(
+ "training_id",
+ )
+ assert_matches_type(TrainingGetResponse, training, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None:
+ response = await async_client.trainings.with_raw_response.get(
+ "training_id",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ training = await response.parse()
+ assert_matches_type(TrainingGetResponse, training, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None:
+ async with async_client.trainings.with_streaming_response.get(
+ "training_id",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ training = await response.parse()
+ assert_matches_type(TrainingGetResponse, training, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `training_id` but received ''"):
+ await async_client.trainings.with_raw_response.get(
+ "",
+ )
From fc34c6d4fc36a41441ab8417f85343e640b53b76 Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Sat, 19 Apr 2025 03:40:50 +0000
Subject: [PATCH 06/12] chore(internal): update models test
---
tests/test_models.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/tests/test_models.py b/tests/test_models.py
index 25f3d43..cf8173c 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -492,12 +492,15 @@ class Model(BaseModel):
resource_id: Optional[str] = None
m = Model.construct()
+ assert m.resource_id is None
assert "resource_id" not in m.model_fields_set
m = Model.construct(resource_id=None)
+ assert m.resource_id is None
assert "resource_id" in m.model_fields_set
m = Model.construct(resource_id="foo")
+ assert m.resource_id == "foo"
assert "resource_id" in m.model_fields_set
From 1bad4d3d3676a323032f37f0195ff640fcce3458 Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Wed, 23 Apr 2025 04:34:13 +0000
Subject: [PATCH 07/12] chore(ci): add timeout thresholds for CI jobs
---
.github/workflows/ci.yml | 2 ++
1 file changed, 2 insertions(+)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 81f6dc2..04b083c 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -10,6 +10,7 @@ on:
jobs:
lint:
+ timeout-minutes: 10
name: lint
runs-on: ubuntu-latest
steps:
@@ -30,6 +31,7 @@ jobs:
run: ./scripts/lint
test:
+ timeout-minutes: 10
name: test
runs-on: ubuntu-latest
steps:
From 4cdf515372a9e936c3a18afd24a444a778b1f7f5 Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Wed, 23 Apr 2025 04:34:45 +0000
Subject: [PATCH 08/12] chore(internal): import reformatting
---
src/replicate/_client.py | 5 +----
src/replicate/resources/deployments/deployments.py | 5 +----
src/replicate/resources/deployments/predictions.py | 6 +-----
src/replicate/resources/models/models.py | 6 +-----
src/replicate/resources/models/versions.py | 5 +----
src/replicate/resources/predictions.py | 6 +-----
6 files changed, 6 insertions(+), 27 deletions(-)
diff --git a/src/replicate/_client.py b/src/replicate/_client.py
index b59e805..6afa0ec 100644
--- a/src/replicate/_client.py
+++ b/src/replicate/_client.py
@@ -19,10 +19,7 @@
ProxiesTypes,
RequestOptions,
)
-from ._utils import (
- is_given,
- get_async_library,
-)
+from ._utils import is_given, get_async_library
from ._version import __version__
from .resources import accounts, hardware, trainings, collections, predictions
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
diff --git a/src/replicate/resources/deployments/deployments.py b/src/replicate/resources/deployments/deployments.py
index 7e2fac7..565a1c6 100644
--- a/src/replicate/resources/deployments/deployments.py
+++ b/src/replicate/resources/deployments/deployments.py
@@ -6,10 +6,7 @@
from ...types import deployment_create_params, deployment_update_params
from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
-from ..._utils import (
- maybe_transform,
- async_maybe_transform,
-)
+from ..._utils import maybe_transform, async_maybe_transform
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
diff --git a/src/replicate/resources/deployments/predictions.py b/src/replicate/resources/deployments/predictions.py
index 7177332..c94fff0 100644
--- a/src/replicate/resources/deployments/predictions.py
+++ b/src/replicate/resources/deployments/predictions.py
@@ -8,11 +8,7 @@
import httpx
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
-from ..._utils import (
- maybe_transform,
- strip_not_given,
- async_maybe_transform,
-)
+from ..._utils import maybe_transform, strip_not_given, async_maybe_transform
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
diff --git a/src/replicate/resources/models/models.py b/src/replicate/resources/models/models.py
index 6d8c6eb..435101f 100644
--- a/src/replicate/resources/models/models.py
+++ b/src/replicate/resources/models/models.py
@@ -9,11 +9,7 @@
from ...types import model_create_params, model_create_prediction_params
from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
-from ..._utils import (
- maybe_transform,
- strip_not_given,
- async_maybe_transform,
-)
+from ..._utils import maybe_transform, strip_not_given, async_maybe_transform
from .versions import (
VersionsResource,
AsyncVersionsResource,
diff --git a/src/replicate/resources/models/versions.py b/src/replicate/resources/models/versions.py
index 2065fe0..308d04b 100644
--- a/src/replicate/resources/models/versions.py
+++ b/src/replicate/resources/models/versions.py
@@ -8,10 +8,7 @@
import httpx
from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
-from ..._utils import (
- maybe_transform,
- async_maybe_transform,
-)
+from ..._utils import maybe_transform, async_maybe_transform
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
diff --git a/src/replicate/resources/predictions.py b/src/replicate/resources/predictions.py
index c03991d..abe0ca1 100644
--- a/src/replicate/resources/predictions.py
+++ b/src/replicate/resources/predictions.py
@@ -10,11 +10,7 @@
from ..types import prediction_list_params, prediction_create_params
from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
-from .._utils import (
- maybe_transform,
- strip_not_given,
- async_maybe_transform,
-)
+from .._utils import maybe_transform, strip_not_given, async_maybe_transform
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
from .._response import (
From 2918ebad39df868485fed02a2d0020bef72d24b9 Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Wed, 23 Apr 2025 04:35:56 +0000
Subject: [PATCH 09/12] chore(internal): fix list file params
---
src/replicate/_utils/_utils.py | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/src/replicate/_utils/_utils.py b/src/replicate/_utils/_utils.py
index e5811bb..ea3cf3f 100644
--- a/src/replicate/_utils/_utils.py
+++ b/src/replicate/_utils/_utils.py
@@ -72,8 +72,16 @@ def _extract_items(
from .._files import assert_is_file_content
# We have exhausted the path, return the entry we found.
- assert_is_file_content(obj, key=flattened_key)
assert flattened_key is not None
+
+ if is_list(obj):
+ files: list[tuple[str, FileTypes]] = []
+ for entry in obj:
+ assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "")
+ files.append((flattened_key + "[]", cast(FileTypes, entry)))
+ return files
+
+ assert_is_file_content(obj, key=flattened_key)
return [(flattened_key, cast(FileTypes, obj))]
index += 1
From 75005e11045385d0596911bbbbb062207450bd14 Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Wed, 23 Apr 2025 04:36:31 +0000
Subject: [PATCH 10/12] chore(internal): refactor retries to not use recursion
---
src/replicate/_base_client.py | 414 ++++++++++++++--------------------
1 file changed, 175 insertions(+), 239 deletions(-)
diff --git a/src/replicate/_base_client.py b/src/replicate/_base_client.py
index fb06997..84db2c9 100644
--- a/src/replicate/_base_client.py
+++ b/src/replicate/_base_client.py
@@ -437,8 +437,7 @@ def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0
headers = httpx.Headers(headers_dict)
idempotency_header = self._idempotency_header
- if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
- options.idempotency_key = options.idempotency_key or self._idempotency_key()
+ if idempotency_header and options.idempotency_key and idempotency_header not in headers:
headers[idempotency_header] = options.idempotency_key
# Don't set these headers if they were already set or removed by the caller. We check
@@ -903,7 +902,6 @@ def request(
self,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
- remaining_retries: Optional[int] = None,
*,
stream: Literal[True],
stream_cls: Type[_StreamT],
@@ -914,7 +912,6 @@ def request(
self,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
- remaining_retries: Optional[int] = None,
*,
stream: Literal[False] = False,
) -> ResponseT: ...
@@ -924,7 +921,6 @@ def request(
self,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
- remaining_retries: Optional[int] = None,
*,
stream: bool = False,
stream_cls: Type[_StreamT] | None = None,
@@ -934,125 +930,109 @@ def request(
self,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
- remaining_retries: Optional[int] = None,
*,
stream: bool = False,
stream_cls: type[_StreamT] | None = None,
) -> ResponseT | _StreamT:
- if remaining_retries is not None:
- retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
- else:
- retries_taken = 0
-
- return self._request(
- cast_to=cast_to,
- options=options,
- stream=stream,
- stream_cls=stream_cls,
- retries_taken=retries_taken,
- )
+ cast_to = self._maybe_override_cast_to(cast_to, options)
- def _request(
- self,
- *,
- cast_to: Type[ResponseT],
- options: FinalRequestOptions,
- retries_taken: int,
- stream: bool,
- stream_cls: type[_StreamT] | None,
- ) -> ResponseT | _StreamT:
# create a copy of the options we were given so that if the
# options are mutated later & we then retry, the retries are
# given the original options
input_options = model_copy(options)
-
- cast_to = self._maybe_override_cast_to(cast_to, options)
- options = self._prepare_options(options)
-
- remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
- request = self._build_request(options, retries_taken=retries_taken)
- self._prepare_request(request)
-
- if options.idempotency_key:
+ if input_options.idempotency_key is None and input_options.method.lower() != "get":
# ensure the idempotency key is reused between requests
- input_options.idempotency_key = options.idempotency_key
+ input_options.idempotency_key = self._idempotency_key()
- kwargs: HttpxSendArgs = {}
- if self.custom_auth is not None:
- kwargs["auth"] = self.custom_auth
+ response: httpx.Response | None = None
+ max_retries = input_options.get_max_retries(self.max_retries)
- log.debug("Sending HTTP Request: %s %s", request.method, request.url)
+ retries_taken = 0
+ for retries_taken in range(max_retries + 1):
+ options = model_copy(input_options)
+ options = self._prepare_options(options)
- try:
- response = self._client.send(
- request,
- stream=stream or self._should_stream_response_body(request=request),
- **kwargs,
- )
- except httpx.TimeoutException as err:
- log.debug("Encountered httpx.TimeoutException", exc_info=True)
+ remaining_retries = max_retries - retries_taken
+ request = self._build_request(options, retries_taken=retries_taken)
+ self._prepare_request(request)
- if remaining_retries > 0:
- return self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- stream=stream,
- stream_cls=stream_cls,
- response_headers=None,
- )
+ kwargs: HttpxSendArgs = {}
+ if self.custom_auth is not None:
+ kwargs["auth"] = self.custom_auth
- log.debug("Raising timeout error")
- raise APITimeoutError(request=request) from err
- except Exception as err:
- log.debug("Encountered Exception", exc_info=True)
+ log.debug("Sending HTTP Request: %s %s", request.method, request.url)
- if remaining_retries > 0:
- return self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- stream=stream,
- stream_cls=stream_cls,
- response_headers=None,
+ response = None
+ try:
+ response = self._client.send(
+ request,
+ stream=stream or self._should_stream_response_body(request=request),
+ **kwargs,
)
+ except httpx.TimeoutException as err:
+ log.debug("Encountered httpx.TimeoutException", exc_info=True)
+
+ if remaining_retries > 0:
+ self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=None,
+ )
+ continue
+
+ log.debug("Raising timeout error")
+ raise APITimeoutError(request=request) from err
+ except Exception as err:
+ log.debug("Encountered Exception", exc_info=True)
+
+ if remaining_retries > 0:
+ self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=None,
+ )
+ continue
+
+ log.debug("Raising connection error")
+ raise APIConnectionError(request=request) from err
+
+ log.debug(
+ 'HTTP Response: %s %s "%i %s" %s',
+ request.method,
+ request.url,
+ response.status_code,
+ response.reason_phrase,
+ response.headers,
+ )
- log.debug("Raising connection error")
- raise APIConnectionError(request=request) from err
-
- log.debug(
- 'HTTP Response: %s %s "%i %s" %s',
- request.method,
- request.url,
- response.status_code,
- response.reason_phrase,
- response.headers,
- )
+ try:
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
+ log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
+
+ if remaining_retries > 0 and self._should_retry(err.response):
+ err.response.close()
+ self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=response,
+ )
+ continue
- try:
- response.raise_for_status()
- except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
- log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
-
- if remaining_retries > 0 and self._should_retry(err.response):
- err.response.close()
- return self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- response_headers=err.response.headers,
- stream=stream,
- stream_cls=stream_cls,
- )
+ # If the response is streamed then we need to explicitly read the response
+ # to completion before attempting to access the response text.
+ if not err.response.is_closed:
+ err.response.read()
- # If the response is streamed then we need to explicitly read the response
- # to completion before attempting to access the response text.
- if not err.response.is_closed:
- err.response.read()
+ log.debug("Re-raising status error")
+ raise self._make_status_error_from_response(err.response) from None
- log.debug("Re-raising status error")
- raise self._make_status_error_from_response(err.response) from None
+ break
+ assert response is not None, "could not resolve response (should never happen)"
return self._process_response(
cast_to=cast_to,
options=options,
@@ -1062,37 +1042,20 @@ def _request(
retries_taken=retries_taken,
)
- def _retry_request(
- self,
- options: FinalRequestOptions,
- cast_to: Type[ResponseT],
- *,
- retries_taken: int,
- response_headers: httpx.Headers | None,
- stream: bool,
- stream_cls: type[_StreamT] | None,
- ) -> ResponseT | _StreamT:
- remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
+ def _sleep_for_retry(
+ self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None
+ ) -> None:
+ remaining_retries = max_retries - retries_taken
if remaining_retries == 1:
log.debug("1 retry left")
else:
log.debug("%i retries left", remaining_retries)
- timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
+ timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None)
log.info("Retrying request to %s in %f seconds", options.url, timeout)
- # In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a
- # different thread if necessary.
time.sleep(timeout)
- return self._request(
- options=options,
- cast_to=cast_to,
- retries_taken=retries_taken + 1,
- stream=stream,
- stream_cls=stream_cls,
- )
-
def _process_response(
self,
*,
@@ -1436,7 +1399,6 @@ async def request(
options: FinalRequestOptions,
*,
stream: Literal[False] = False,
- remaining_retries: Optional[int] = None,
) -> ResponseT: ...
@overload
@@ -1447,7 +1409,6 @@ async def request(
*,
stream: Literal[True],
stream_cls: type[_AsyncStreamT],
- remaining_retries: Optional[int] = None,
) -> _AsyncStreamT: ...
@overload
@@ -1458,7 +1419,6 @@ async def request(
*,
stream: bool,
stream_cls: type[_AsyncStreamT] | None = None,
- remaining_retries: Optional[int] = None,
) -> ResponseT | _AsyncStreamT: ...
async def request(
@@ -1468,120 +1428,111 @@ async def request(
*,
stream: bool = False,
stream_cls: type[_AsyncStreamT] | None = None,
- remaining_retries: Optional[int] = None,
- ) -> ResponseT | _AsyncStreamT:
- if remaining_retries is not None:
- retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
- else:
- retries_taken = 0
-
- return await self._request(
- cast_to=cast_to,
- options=options,
- stream=stream,
- stream_cls=stream_cls,
- retries_taken=retries_taken,
- )
-
- async def _request(
- self,
- cast_to: Type[ResponseT],
- options: FinalRequestOptions,
- *,
- stream: bool,
- stream_cls: type[_AsyncStreamT] | None,
- retries_taken: int,
) -> ResponseT | _AsyncStreamT:
if self._platform is None:
# `get_platform` can make blocking IO calls so we
# execute it earlier while we are in an async context
self._platform = await asyncify(get_platform)()
+ cast_to = self._maybe_override_cast_to(cast_to, options)
+
# create a copy of the options we were given so that if the
# options are mutated later & we then retry, the retries are
# given the original options
input_options = model_copy(options)
-
- cast_to = self._maybe_override_cast_to(cast_to, options)
- options = await self._prepare_options(options)
-
- remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
- request = self._build_request(options, retries_taken=retries_taken)
- await self._prepare_request(request)
-
- if options.idempotency_key:
+ if input_options.idempotency_key is None and input_options.method.lower() != "get":
# ensure the idempotency key is reused between requests
- input_options.idempotency_key = options.idempotency_key
+ input_options.idempotency_key = self._idempotency_key()
- kwargs: HttpxSendArgs = {}
- if self.custom_auth is not None:
- kwargs["auth"] = self.custom_auth
+ response: httpx.Response | None = None
+ max_retries = input_options.get_max_retries(self.max_retries)
- try:
- response = await self._client.send(
- request,
- stream=stream or self._should_stream_response_body(request=request),
- **kwargs,
- )
- except httpx.TimeoutException as err:
- log.debug("Encountered httpx.TimeoutException", exc_info=True)
+ retries_taken = 0
+ for retries_taken in range(max_retries + 1):
+ options = model_copy(input_options)
+ options = await self._prepare_options(options)
- if remaining_retries > 0:
- return await self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- stream=stream,
- stream_cls=stream_cls,
- response_headers=None,
- )
+ remaining_retries = max_retries - retries_taken
+ request = self._build_request(options, retries_taken=retries_taken)
+ await self._prepare_request(request)
- log.debug("Raising timeout error")
- raise APITimeoutError(request=request) from err
- except Exception as err:
- log.debug("Encountered Exception", exc_info=True)
+ kwargs: HttpxSendArgs = {}
+ if self.custom_auth is not None:
+ kwargs["auth"] = self.custom_auth
- if remaining_retries > 0:
- return await self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- stream=stream,
- stream_cls=stream_cls,
- response_headers=None,
- )
+ log.debug("Sending HTTP Request: %s %s", request.method, request.url)
- log.debug("Raising connection error")
- raise APIConnectionError(request=request) from err
+ response = None
+ try:
+ response = await self._client.send(
+ request,
+ stream=stream or self._should_stream_response_body(request=request),
+ **kwargs,
+ )
+ except httpx.TimeoutException as err:
+ log.debug("Encountered httpx.TimeoutException", exc_info=True)
+
+ if remaining_retries > 0:
+ await self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=None,
+ )
+ continue
+
+ log.debug("Raising timeout error")
+ raise APITimeoutError(request=request) from err
+ except Exception as err:
+ log.debug("Encountered Exception", exc_info=True)
+
+ if remaining_retries > 0:
+ await self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=None,
+ )
+ continue
+
+ log.debug("Raising connection error")
+ raise APIConnectionError(request=request) from err
+
+ log.debug(
+ 'HTTP Response: %s %s "%i %s" %s',
+ request.method,
+ request.url,
+ response.status_code,
+ response.reason_phrase,
+ response.headers,
+ )
- log.debug(
- 'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
- )
+ try:
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
+ log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
+
+ if remaining_retries > 0 and self._should_retry(err.response):
+ await err.response.aclose()
+ await self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=response,
+ )
+ continue
- try:
- response.raise_for_status()
- except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
- log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
-
- if remaining_retries > 0 and self._should_retry(err.response):
- await err.response.aclose()
- return await self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- response_headers=err.response.headers,
- stream=stream,
- stream_cls=stream_cls,
- )
+ # If the response is streamed then we need to explicitly read the response
+ # to completion before attempting to access the response text.
+ if not err.response.is_closed:
+ await err.response.aread()
- # If the response is streamed then we need to explicitly read the response
- # to completion before attempting to access the response text.
- if not err.response.is_closed:
- await err.response.aread()
+ log.debug("Re-raising status error")
+ raise self._make_status_error_from_response(err.response) from None
- log.debug("Re-raising status error")
- raise self._make_status_error_from_response(err.response) from None
+ break
+ assert response is not None, "could not resolve response (should never happen)"
return await self._process_response(
cast_to=cast_to,
options=options,
@@ -1591,35 +1542,20 @@ async def _request(
retries_taken=retries_taken,
)
- async def _retry_request(
- self,
- options: FinalRequestOptions,
- cast_to: Type[ResponseT],
- *,
- retries_taken: int,
- response_headers: httpx.Headers | None,
- stream: bool,
- stream_cls: type[_AsyncStreamT] | None,
- ) -> ResponseT | _AsyncStreamT:
- remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
+ async def _sleep_for_retry(
+ self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None
+ ) -> None:
+ remaining_retries = max_retries - retries_taken
if remaining_retries == 1:
log.debug("1 retry left")
else:
log.debug("%i retries left", remaining_retries)
- timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
+ timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None)
log.info("Retrying request to %s in %f seconds", options.url, timeout)
await anyio.sleep(timeout)
- return await self._request(
- options=options,
- cast_to=cast_to,
- retries_taken=retries_taken + 1,
- stream=stream,
- stream_cls=stream_cls,
- )
-
async def _process_response(
self,
*,
From c907599a6736e781f3f80062eb4d03ed92f03403 Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Wed, 23 Apr 2025 04:36:56 +0000
Subject: [PATCH 11/12] fix(pydantic v1): more robust ModelField.annotation
check
---
src/replicate/_models.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/replicate/_models.py b/src/replicate/_models.py
index 58b9263..798956f 100644
--- a/src/replicate/_models.py
+++ b/src/replicate/_models.py
@@ -626,8 +626,8 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
# Note: if one variant defines an alias then they all should
discriminator_alias = field_info.alias
- if field_info.annotation and is_literal_type(field_info.annotation):
- for entry in get_args(field_info.annotation):
+ if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation):
+ for entry in get_args(annotation):
if isinstance(entry, str):
mapping[entry] = variant
From 8d82164749ccdcb0e1fd9700659281a3db5abfb2 Mon Sep 17 00:00:00 2001
From: "stainless-app[bot]"
<142633134+stainless-app[bot]@users.noreply.github.com>
Date: Wed, 23 Apr 2025 04:37:12 +0000
Subject: [PATCH 12/12] release: 0.1.0-alpha.3
---
.release-please-manifest.json | 2 +-
CHANGELOG.md | 29 +++++++++++++++++++++++++++++
pyproject.toml | 2 +-
src/replicate/_version.py | 2 +-
4 files changed, 32 insertions(+), 3 deletions(-)
diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index f14b480..aaf968a 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
{
- ".": "0.1.0-alpha.2"
+ ".": "0.1.0-alpha.3"
}
\ No newline at end of file
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d51ac7f..6bde67f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,34 @@
# Changelog
+## 0.1.0-alpha.3 (2025-04-23)
+
+Full Changelog: [v0.1.0-alpha.2...v0.1.0-alpha.3](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0-alpha.2...v0.1.0-alpha.3)
+
+### ⚠ BREAKING CHANGES
+
+* **api:** use correct env var for bearer token
+
+### Features
+
+* **api:** api update ([7ebd598](https://github.com/replicate/replicate-python-stainless/commit/7ebd59873181c74dbaa035ac599abcbbefb3ee62))
+
+
+### Bug Fixes
+
+* **api:** use correct env var for bearer token ([00eab77](https://github.com/replicate/replicate-python-stainless/commit/00eab7702f8f2699ce9b3070f23202278ac21855))
+* **pydantic v1:** more robust ModelField.annotation check ([c907599](https://github.com/replicate/replicate-python-stainless/commit/c907599a6736e781f3f80062eb4d03ed92f03403))
+
+
+### Chores
+
+* **ci:** add timeout thresholds for CI jobs ([1bad4d3](https://github.com/replicate/replicate-python-stainless/commit/1bad4d3d3676a323032f37f0195ff640fcce3458))
+* **internal:** base client updates ([c1d6ed5](https://github.com/replicate/replicate-python-stainless/commit/c1d6ed59ed0f06012922ec6d0bae376852523d81))
+* **internal:** bump pyright version ([f1e4d14](https://github.com/replicate/replicate-python-stainless/commit/f1e4d140104ff317b94cb2dd88ec850a9b8bce54))
+* **internal:** fix list file params ([2918eba](https://github.com/replicate/replicate-python-stainless/commit/2918ebad39df868485fed02a2d0020bef72d24b9))
+* **internal:** import reformatting ([4cdf515](https://github.com/replicate/replicate-python-stainless/commit/4cdf515372a9e936c3a18afd24a444a778b1f7f5))
+* **internal:** refactor retries to not use recursion ([75005e1](https://github.com/replicate/replicate-python-stainless/commit/75005e11045385d0596911bbbbb062207450bd14))
+* **internal:** update models test ([fc34c6d](https://github.com/replicate/replicate-python-stainless/commit/fc34c6d4fc36a41441ab8417f85343e640b53b76))
+
## 0.1.0-alpha.2 (2025-04-16)
Full Changelog: [v0.1.0-alpha.1...v0.1.0-alpha.2](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0-alpha.1...v0.1.0-alpha.2)
diff --git a/pyproject.toml b/pyproject.toml
index 8eb0e54..4376f13 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "replicate-stainless"
-version = "0.1.0-alpha.2"
+version = "0.1.0-alpha.3"
description = "The official Python library for the replicate-client API"
dynamic = ["readme"]
license = "Apache-2.0"
diff --git a/src/replicate/_version.py b/src/replicate/_version.py
index dfeb99c..5b5e6c4 100644
--- a/src/replicate/_version.py
+++ b/src/replicate/_version.py
@@ -1,4 +1,4 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
__title__ = "replicate"
-__version__ = "0.1.0-alpha.2" # x-release-please-version
+__version__ = "0.1.0-alpha.3" # x-release-please-version