diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 81f6dc2..04b083c 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -10,6 +10,7 @@ on:
jobs:
lint:
+ timeout-minutes: 10
name: lint
runs-on: ubuntu-latest
steps:
@@ -30,6 +31,7 @@ jobs:
run: ./scripts/lint
test:
+ timeout-minutes: 10
name: test
runs-on: ubuntu-latest
steps:
diff --git a/.release-please-manifest.json b/.release-please-manifest.json
index f14b480..aaf968a 100644
--- a/.release-please-manifest.json
+++ b/.release-please-manifest.json
@@ -1,3 +1,3 @@
{
- ".": "0.1.0-alpha.2"
+ ".": "0.1.0-alpha.3"
}
\ No newline at end of file
diff --git a/.stats.yml b/.stats.yml
index 91aadf3..190fe36 100644
--- a/.stats.yml
+++ b/.stats.yml
@@ -1,4 +1,4 @@
configured_endpoints: 27
-openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-37bb31ed76da599d3bded543a3765f745c8575d105c13554df7f8361c3641482.yml
-openapi_spec_hash: 15bdec12ca84042768bfb28cc48dfce3
-config_hash: 810de4c2eee1a7649263cff01f00da7c
+openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-2788217b7ad7d61d1a77800bc5ff12a6810f1692d4d770b72fa8f898c6a055ab.yml
+openapi_spec_hash: 4423bf747e228484547b441468a9f156
+config_hash: d1d273c0d97d034d24c7eac8ef51d2ac
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d51ac7f..6bde67f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,34 @@
# Changelog
+## 0.1.0-alpha.3 (2025-04-23)
+
+Full Changelog: [v0.1.0-alpha.2...v0.1.0-alpha.3](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0-alpha.2...v0.1.0-alpha.3)
+
+### ⚠ BREAKING CHANGES
+
+* **api:** use correct env var for bearer token
+
+### Features
+
+* **api:** api update ([7ebd598](https://github.com/replicate/replicate-python-stainless/commit/7ebd59873181c74dbaa035ac599abcbbefb3ee62))
+
+
+### Bug Fixes
+
+* **api:** use correct env var for bearer token ([00eab77](https://github.com/replicate/replicate-python-stainless/commit/00eab7702f8f2699ce9b3070f23202278ac21855))
+* **pydantic v1:** more robust ModelField.annotation check ([c907599](https://github.com/replicate/replicate-python-stainless/commit/c907599a6736e781f3f80062eb4d03ed92f03403))
+
+
+### Chores
+
+* **ci:** add timeout thresholds for CI jobs ([1bad4d3](https://github.com/replicate/replicate-python-stainless/commit/1bad4d3d3676a323032f37f0195ff640fcce3458))
+* **internal:** base client updates ([c1d6ed5](https://github.com/replicate/replicate-python-stainless/commit/c1d6ed59ed0f06012922ec6d0bae376852523d81))
+* **internal:** bump pyright version ([f1e4d14](https://github.com/replicate/replicate-python-stainless/commit/f1e4d140104ff317b94cb2dd88ec850a9b8bce54))
+* **internal:** fix list file params ([2918eba](https://github.com/replicate/replicate-python-stainless/commit/2918ebad39df868485fed02a2d0020bef72d24b9))
+* **internal:** import reformatting ([4cdf515](https://github.com/replicate/replicate-python-stainless/commit/4cdf515372a9e936c3a18afd24a444a778b1f7f5))
+* **internal:** refactor retries to not use recursion ([75005e1](https://github.com/replicate/replicate-python-stainless/commit/75005e11045385d0596911bbbbb062207450bd14))
+* **internal:** update models test ([fc34c6d](https://github.com/replicate/replicate-python-stainless/commit/fc34c6d4fc36a41441ab8417f85343e640b53b76))
+
## 0.1.0-alpha.2 (2025-04-16)
Full Changelog: [v0.1.0-alpha.1...v0.1.0-alpha.2](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0-alpha.1...v0.1.0-alpha.2)
diff --git a/README.md b/README.md
index 0015b4c..23b57ae 100644
--- a/README.md
+++ b/README.md
@@ -28,9 +28,7 @@ import os
from replicate import ReplicateClient
client = ReplicateClient(
- bearer_token=os.environ.get(
- "REPLICATE_CLIENT_BEARER_TOKEN"
- ), # This is the default and can be omitted
+ bearer_token=os.environ.get("REPLICATE_API_TOKEN"), # This is the default and can be omitted
)
accounts = client.accounts.list()
@@ -39,7 +37,7 @@ print(accounts.type)
While you can provide a `bearer_token` keyword argument,
we recommend using [python-dotenv](https://pypi.org/project/python-dotenv/)
-to add `REPLICATE_CLIENT_BEARER_TOKEN="My Bearer Token"` to your `.env` file
+to add `REPLICATE_API_TOKEN="My Bearer Token"` to your `.env` file
so that your Bearer Token is not stored in source control.
## Async usage
@@ -52,9 +50,7 @@ import asyncio
from replicate import AsyncReplicateClient
client = AsyncReplicateClient(
- bearer_token=os.environ.get(
- "REPLICATE_CLIENT_BEARER_TOKEN"
- ), # This is the default and can be omitted
+ bearer_token=os.environ.get("REPLICATE_API_TOKEN"), # This is the default and can be omitted
)
diff --git a/api.md b/api.md
index 60e91e6..e972297 100644
--- a/api.md
+++ b/api.md
@@ -11,19 +11,19 @@ Types:
```python
from replicate.types import (
DeploymentCreateResponse,
- DeploymentRetrieveResponse,
DeploymentUpdateResponse,
DeploymentListResponse,
+ DeploymentGetResponse,
)
```
Methods:
- client.deployments.create(\*\*params) -> DeploymentCreateResponse
-- client.deployments.retrieve(deployment_name, \*, deployment_owner) -> DeploymentRetrieveResponse
- client.deployments.update(deployment_name, \*, deployment_owner, \*\*params) -> DeploymentUpdateResponse
- client.deployments.list() -> SyncCursorURLPage[DeploymentListResponse]
- client.deployments.delete(deployment_name, \*, deployment_owner) -> None
+- client.deployments.get(deployment_name, \*, deployment_owner) -> DeploymentGetResponse
- client.deployments.list_em_all() -> None
## Predictions
@@ -68,19 +68,25 @@ from replicate.types import ModelListResponse
Methods:
- client.models.create(\*\*params) -> None
-- client.models.retrieve(model_name, \*, model_owner) -> None
- client.models.list() -> SyncCursorURLPage[ModelListResponse]
- client.models.delete(model_name, \*, model_owner) -> None
- client.models.create_prediction(model_name, \*, model_owner, \*\*params) -> Prediction
+- client.models.get(model_name, \*, model_owner) -> None
## Versions
+Types:
+
+```python
+from replicate.types.models import VersionCreateTrainingResponse
+```
+
Methods:
-- client.models.versions.retrieve(version_id, \*, model_owner, model_name) -> None
- client.models.versions.list(model_name, \*, model_owner) -> None
- client.models.versions.delete(version_id, \*, model_owner, model_name) -> None
-- client.models.versions.create_training(version_id, \*, model_owner, model_name, \*\*params) -> None
+- client.models.versions.create_training(version_id, \*, model_owner, model_name, \*\*params) -> VersionCreateTrainingResponse
+- client.models.versions.get(version_id, \*, model_owner, model_name) -> None
# Predictions
@@ -93,17 +99,23 @@ from replicate.types import Prediction, PredictionOutput, PredictionRequest
Methods:
- client.predictions.create(\*\*params) -> Prediction
-- client.predictions.retrieve(prediction_id) -> Prediction
- client.predictions.list(\*\*params) -> SyncCursorURLPageWithCreatedFilters[Prediction]
- client.predictions.cancel(prediction_id) -> None
+- client.predictions.get(prediction_id) -> Prediction
# Trainings
+Types:
+
+```python
+from replicate.types import TrainingListResponse, TrainingCancelResponse, TrainingGetResponse
+```
+
Methods:
-- client.trainings.retrieve(training_id) -> None
-- client.trainings.list() -> None
-- client.trainings.cancel(training_id) -> None
+- client.trainings.list() -> SyncCursorURLPage[TrainingListResponse]
+- client.trainings.cancel(training_id) -> TrainingCancelResponse
+- client.trainings.get(training_id) -> TrainingGetResponse
# Webhooks
diff --git a/pyproject.toml b/pyproject.toml
index d9681a8..4376f13 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "replicate-stainless"
-version = "0.1.0-alpha.2"
+version = "0.1.0-alpha.3"
description = "The official Python library for the replicate-client API"
dynamic = ["readme"]
license = "Apache-2.0"
@@ -42,7 +42,7 @@ Repository = "https://github.com/replicate/replicate-python-stainless"
managed = true
# version pins are in requirements-dev.lock
dev-dependencies = [
- "pyright>=1.1.359",
+ "pyright==1.1.399",
"mypy",
"respx",
"pytest",
diff --git a/requirements-dev.lock b/requirements-dev.lock
index e9ea098..86eea12 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -69,7 +69,7 @@ pydantic-core==2.27.1
# via pydantic
pygments==2.18.0
# via rich
-pyright==1.1.392.post0
+pyright==1.1.399
pytest==8.3.3
# via pytest-asyncio
pytest-asyncio==0.24.0
diff --git a/src/replicate/_base_client.py b/src/replicate/_base_client.py
index d55fecf..84db2c9 100644
--- a/src/replicate/_base_client.py
+++ b/src/replicate/_base_client.py
@@ -98,7 +98,11 @@
_AsyncStreamT = TypeVar("_AsyncStreamT", bound=AsyncStream[Any])
if TYPE_CHECKING:
- from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
+ from httpx._config import (
+ DEFAULT_TIMEOUT_CONFIG, # pyright: ignore[reportPrivateImportUsage]
+ )
+
+ HTTPX_DEFAULT_TIMEOUT = DEFAULT_TIMEOUT_CONFIG
else:
try:
from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
@@ -115,6 +119,7 @@ class PageInfo:
url: URL | NotGiven
params: Query | NotGiven
+ json: Body | NotGiven
@overload
def __init__(
@@ -130,19 +135,30 @@ def __init__(
params: Query,
) -> None: ...
+ @overload
+ def __init__(
+ self,
+ *,
+ json: Body,
+ ) -> None: ...
+
def __init__(
self,
*,
url: URL | NotGiven = NOT_GIVEN,
+ json: Body | NotGiven = NOT_GIVEN,
params: Query | NotGiven = NOT_GIVEN,
) -> None:
self.url = url
+ self.json = json
self.params = params
@override
def __repr__(self) -> str:
if self.url:
return f"{self.__class__.__name__}(url={self.url})"
+ if self.json:
+ return f"{self.__class__.__name__}(json={self.json})"
return f"{self.__class__.__name__}(params={self.params})"
@@ -191,6 +207,19 @@ def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:
options.url = str(url)
return options
+ if not isinstance(info.json, NotGiven):
+ if not is_mapping(info.json):
+ raise TypeError("Pagination is only supported with mappings")
+
+ if not options.json_data:
+ options.json_data = {**info.json}
+ else:
+ if not is_mapping(options.json_data):
+ raise TypeError("Pagination is only supported with mappings")
+
+ options.json_data = {**options.json_data, **info.json}
+ return options
+
raise ValueError("Unexpected PageInfo state")
@@ -408,8 +437,7 @@ def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0
headers = httpx.Headers(headers_dict)
idempotency_header = self._idempotency_header
- if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
- options.idempotency_key = options.idempotency_key or self._idempotency_key()
+ if idempotency_header and options.idempotency_key and idempotency_header not in headers:
headers[idempotency_header] = options.idempotency_key
# Don't set these headers if they were already set or removed by the caller. We check
@@ -874,7 +902,6 @@ def request(
self,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
- remaining_retries: Optional[int] = None,
*,
stream: Literal[True],
stream_cls: Type[_StreamT],
@@ -885,7 +912,6 @@ def request(
self,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
- remaining_retries: Optional[int] = None,
*,
stream: Literal[False] = False,
) -> ResponseT: ...
@@ -895,7 +921,6 @@ def request(
self,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
- remaining_retries: Optional[int] = None,
*,
stream: bool = False,
stream_cls: Type[_StreamT] | None = None,
@@ -905,125 +930,109 @@ def request(
self,
cast_to: Type[ResponseT],
options: FinalRequestOptions,
- remaining_retries: Optional[int] = None,
*,
stream: bool = False,
stream_cls: type[_StreamT] | None = None,
) -> ResponseT | _StreamT:
- if remaining_retries is not None:
- retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
- else:
- retries_taken = 0
-
- return self._request(
- cast_to=cast_to,
- options=options,
- stream=stream,
- stream_cls=stream_cls,
- retries_taken=retries_taken,
- )
+ cast_to = self._maybe_override_cast_to(cast_to, options)
- def _request(
- self,
- *,
- cast_to: Type[ResponseT],
- options: FinalRequestOptions,
- retries_taken: int,
- stream: bool,
- stream_cls: type[_StreamT] | None,
- ) -> ResponseT | _StreamT:
# create a copy of the options we were given so that if the
# options are mutated later & we then retry, the retries are
# given the original options
input_options = model_copy(options)
-
- cast_to = self._maybe_override_cast_to(cast_to, options)
- options = self._prepare_options(options)
-
- remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
- request = self._build_request(options, retries_taken=retries_taken)
- self._prepare_request(request)
-
- if options.idempotency_key:
+ if input_options.idempotency_key is None and input_options.method.lower() != "get":
# ensure the idempotency key is reused between requests
- input_options.idempotency_key = options.idempotency_key
+ input_options.idempotency_key = self._idempotency_key()
- kwargs: HttpxSendArgs = {}
- if self.custom_auth is not None:
- kwargs["auth"] = self.custom_auth
+ response: httpx.Response | None = None
+ max_retries = input_options.get_max_retries(self.max_retries)
- log.debug("Sending HTTP Request: %s %s", request.method, request.url)
+ retries_taken = 0
+ for retries_taken in range(max_retries + 1):
+ options = model_copy(input_options)
+ options = self._prepare_options(options)
- try:
- response = self._client.send(
- request,
- stream=stream or self._should_stream_response_body(request=request),
- **kwargs,
- )
- except httpx.TimeoutException as err:
- log.debug("Encountered httpx.TimeoutException", exc_info=True)
+ remaining_retries = max_retries - retries_taken
+ request = self._build_request(options, retries_taken=retries_taken)
+ self._prepare_request(request)
- if remaining_retries > 0:
- return self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- stream=stream,
- stream_cls=stream_cls,
- response_headers=None,
- )
+ kwargs: HttpxSendArgs = {}
+ if self.custom_auth is not None:
+ kwargs["auth"] = self.custom_auth
- log.debug("Raising timeout error")
- raise APITimeoutError(request=request) from err
- except Exception as err:
- log.debug("Encountered Exception", exc_info=True)
+ log.debug("Sending HTTP Request: %s %s", request.method, request.url)
- if remaining_retries > 0:
- return self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- stream=stream,
- stream_cls=stream_cls,
- response_headers=None,
+ response = None
+ try:
+ response = self._client.send(
+ request,
+ stream=stream or self._should_stream_response_body(request=request),
+ **kwargs,
)
+ except httpx.TimeoutException as err:
+ log.debug("Encountered httpx.TimeoutException", exc_info=True)
+
+ if remaining_retries > 0:
+ self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=None,
+ )
+ continue
+
+ log.debug("Raising timeout error")
+ raise APITimeoutError(request=request) from err
+ except Exception as err:
+ log.debug("Encountered Exception", exc_info=True)
+
+ if remaining_retries > 0:
+ self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=None,
+ )
+ continue
+
+ log.debug("Raising connection error")
+ raise APIConnectionError(request=request) from err
+
+ log.debug(
+ 'HTTP Response: %s %s "%i %s" %s',
+ request.method,
+ request.url,
+ response.status_code,
+ response.reason_phrase,
+ response.headers,
+ )
- log.debug("Raising connection error")
- raise APIConnectionError(request=request) from err
-
- log.debug(
- 'HTTP Response: %s %s "%i %s" %s',
- request.method,
- request.url,
- response.status_code,
- response.reason_phrase,
- response.headers,
- )
+ try:
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
+ log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
+
+ if remaining_retries > 0 and self._should_retry(err.response):
+ err.response.close()
+ self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=response,
+ )
+ continue
- try:
- response.raise_for_status()
- except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
- log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
-
- if remaining_retries > 0 and self._should_retry(err.response):
- err.response.close()
- return self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- response_headers=err.response.headers,
- stream=stream,
- stream_cls=stream_cls,
- )
+ # If the response is streamed then we need to explicitly read the response
+ # to completion before attempting to access the response text.
+ if not err.response.is_closed:
+ err.response.read()
- # If the response is streamed then we need to explicitly read the response
- # to completion before attempting to access the response text.
- if not err.response.is_closed:
- err.response.read()
+ log.debug("Re-raising status error")
+ raise self._make_status_error_from_response(err.response) from None
- log.debug("Re-raising status error")
- raise self._make_status_error_from_response(err.response) from None
+ break
+ assert response is not None, "could not resolve response (should never happen)"
return self._process_response(
cast_to=cast_to,
options=options,
@@ -1033,37 +1042,20 @@ def _request(
retries_taken=retries_taken,
)
- def _retry_request(
- self,
- options: FinalRequestOptions,
- cast_to: Type[ResponseT],
- *,
- retries_taken: int,
- response_headers: httpx.Headers | None,
- stream: bool,
- stream_cls: type[_StreamT] | None,
- ) -> ResponseT | _StreamT:
- remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
+ def _sleep_for_retry(
+ self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None
+ ) -> None:
+ remaining_retries = max_retries - retries_taken
if remaining_retries == 1:
log.debug("1 retry left")
else:
log.debug("%i retries left", remaining_retries)
- timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
+ timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None)
log.info("Retrying request to %s in %f seconds", options.url, timeout)
- # In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a
- # different thread if necessary.
time.sleep(timeout)
- return self._request(
- options=options,
- cast_to=cast_to,
- retries_taken=retries_taken + 1,
- stream=stream,
- stream_cls=stream_cls,
- )
-
def _process_response(
self,
*,
@@ -1407,7 +1399,6 @@ async def request(
options: FinalRequestOptions,
*,
stream: Literal[False] = False,
- remaining_retries: Optional[int] = None,
) -> ResponseT: ...
@overload
@@ -1418,7 +1409,6 @@ async def request(
*,
stream: Literal[True],
stream_cls: type[_AsyncStreamT],
- remaining_retries: Optional[int] = None,
) -> _AsyncStreamT: ...
@overload
@@ -1429,7 +1419,6 @@ async def request(
*,
stream: bool,
stream_cls: type[_AsyncStreamT] | None = None,
- remaining_retries: Optional[int] = None,
) -> ResponseT | _AsyncStreamT: ...
async def request(
@@ -1439,120 +1428,111 @@ async def request(
*,
stream: bool = False,
stream_cls: type[_AsyncStreamT] | None = None,
- remaining_retries: Optional[int] = None,
- ) -> ResponseT | _AsyncStreamT:
- if remaining_retries is not None:
- retries_taken = options.get_max_retries(self.max_retries) - remaining_retries
- else:
- retries_taken = 0
-
- return await self._request(
- cast_to=cast_to,
- options=options,
- stream=stream,
- stream_cls=stream_cls,
- retries_taken=retries_taken,
- )
-
- async def _request(
- self,
- cast_to: Type[ResponseT],
- options: FinalRequestOptions,
- *,
- stream: bool,
- stream_cls: type[_AsyncStreamT] | None,
- retries_taken: int,
) -> ResponseT | _AsyncStreamT:
if self._platform is None:
# `get_platform` can make blocking IO calls so we
# execute it earlier while we are in an async context
self._platform = await asyncify(get_platform)()
+ cast_to = self._maybe_override_cast_to(cast_to, options)
+
# create a copy of the options we were given so that if the
# options are mutated later & we then retry, the retries are
# given the original options
input_options = model_copy(options)
-
- cast_to = self._maybe_override_cast_to(cast_to, options)
- options = await self._prepare_options(options)
-
- remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
- request = self._build_request(options, retries_taken=retries_taken)
- await self._prepare_request(request)
-
- if options.idempotency_key:
+ if input_options.idempotency_key is None and input_options.method.lower() != "get":
# ensure the idempotency key is reused between requests
- input_options.idempotency_key = options.idempotency_key
+ input_options.idempotency_key = self._idempotency_key()
- kwargs: HttpxSendArgs = {}
- if self.custom_auth is not None:
- kwargs["auth"] = self.custom_auth
+ response: httpx.Response | None = None
+ max_retries = input_options.get_max_retries(self.max_retries)
- try:
- response = await self._client.send(
- request,
- stream=stream or self._should_stream_response_body(request=request),
- **kwargs,
- )
- except httpx.TimeoutException as err:
- log.debug("Encountered httpx.TimeoutException", exc_info=True)
+ retries_taken = 0
+ for retries_taken in range(max_retries + 1):
+ options = model_copy(input_options)
+ options = await self._prepare_options(options)
- if remaining_retries > 0:
- return await self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- stream=stream,
- stream_cls=stream_cls,
- response_headers=None,
- )
+ remaining_retries = max_retries - retries_taken
+ request = self._build_request(options, retries_taken=retries_taken)
+ await self._prepare_request(request)
- log.debug("Raising timeout error")
- raise APITimeoutError(request=request) from err
- except Exception as err:
- log.debug("Encountered Exception", exc_info=True)
+ kwargs: HttpxSendArgs = {}
+ if self.custom_auth is not None:
+ kwargs["auth"] = self.custom_auth
- if remaining_retries > 0:
- return await self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- stream=stream,
- stream_cls=stream_cls,
- response_headers=None,
- )
+ log.debug("Sending HTTP Request: %s %s", request.method, request.url)
- log.debug("Raising connection error")
- raise APIConnectionError(request=request) from err
+ response = None
+ try:
+ response = await self._client.send(
+ request,
+ stream=stream or self._should_stream_response_body(request=request),
+ **kwargs,
+ )
+ except httpx.TimeoutException as err:
+ log.debug("Encountered httpx.TimeoutException", exc_info=True)
+
+ if remaining_retries > 0:
+ await self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=None,
+ )
+ continue
+
+ log.debug("Raising timeout error")
+ raise APITimeoutError(request=request) from err
+ except Exception as err:
+ log.debug("Encountered Exception", exc_info=True)
+
+ if remaining_retries > 0:
+ await self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=None,
+ )
+ continue
+
+ log.debug("Raising connection error")
+ raise APIConnectionError(request=request) from err
+
+ log.debug(
+ 'HTTP Response: %s %s "%i %s" %s',
+ request.method,
+ request.url,
+ response.status_code,
+ response.reason_phrase,
+ response.headers,
+ )
- log.debug(
- 'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase
- )
+ try:
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
+ log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
+
+ if remaining_retries > 0 and self._should_retry(err.response):
+ await err.response.aclose()
+ await self._sleep_for_retry(
+ retries_taken=retries_taken,
+ max_retries=max_retries,
+ options=input_options,
+ response=response,
+ )
+ continue
- try:
- response.raise_for_status()
- except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code
- log.debug("Encountered httpx.HTTPStatusError", exc_info=True)
-
- if remaining_retries > 0 and self._should_retry(err.response):
- await err.response.aclose()
- return await self._retry_request(
- input_options,
- cast_to,
- retries_taken=retries_taken,
- response_headers=err.response.headers,
- stream=stream,
- stream_cls=stream_cls,
- )
+ # If the response is streamed then we need to explicitly read the response
+ # to completion before attempting to access the response text.
+ if not err.response.is_closed:
+ await err.response.aread()
- # If the response is streamed then we need to explicitly read the response
- # to completion before attempting to access the response text.
- if not err.response.is_closed:
- await err.response.aread()
+ log.debug("Re-raising status error")
+ raise self._make_status_error_from_response(err.response) from None
- log.debug("Re-raising status error")
- raise self._make_status_error_from_response(err.response) from None
+ break
+ assert response is not None, "could not resolve response (should never happen)"
return await self._process_response(
cast_to=cast_to,
options=options,
@@ -1562,35 +1542,20 @@ async def _request(
retries_taken=retries_taken,
)
- async def _retry_request(
- self,
- options: FinalRequestOptions,
- cast_to: Type[ResponseT],
- *,
- retries_taken: int,
- response_headers: httpx.Headers | None,
- stream: bool,
- stream_cls: type[_AsyncStreamT] | None,
- ) -> ResponseT | _AsyncStreamT:
- remaining_retries = options.get_max_retries(self.max_retries) - retries_taken
+ async def _sleep_for_retry(
+ self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None
+ ) -> None:
+ remaining_retries = max_retries - retries_taken
if remaining_retries == 1:
log.debug("1 retry left")
else:
log.debug("%i retries left", remaining_retries)
- timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers)
+ timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None)
log.info("Retrying request to %s in %f seconds", options.url, timeout)
await anyio.sleep(timeout)
- return await self._request(
- options=options,
- cast_to=cast_to,
- retries_taken=retries_taken + 1,
- stream=stream,
- stream_cls=stream_cls,
- )
-
async def _process_response(
self,
*,
diff --git a/src/replicate/_client.py b/src/replicate/_client.py
index 06e86ff..6afa0ec 100644
--- a/src/replicate/_client.py
+++ b/src/replicate/_client.py
@@ -19,10 +19,7 @@
ProxiesTypes,
RequestOptions,
)
-from ._utils import (
- is_given,
- get_async_library,
-)
+from ._utils import is_given, get_async_library
from ._version import __version__
from .resources import accounts, hardware, trainings, collections, predictions
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
@@ -88,13 +85,13 @@ def __init__(
) -> None:
"""Construct a new synchronous ReplicateClient client instance.
- This automatically infers the `bearer_token` argument from the `REPLICATE_CLIENT_BEARER_TOKEN` environment variable if it is not provided.
+ This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided.
"""
if bearer_token is None:
- bearer_token = os.environ.get("REPLICATE_CLIENT_BEARER_TOKEN")
+ bearer_token = os.environ.get("REPLICATE_API_TOKEN")
if bearer_token is None:
raise ReplicateClientError(
- "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_CLIENT_BEARER_TOKEN environment variable"
+ "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable"
)
self.bearer_token = bearer_token
@@ -270,13 +267,13 @@ def __init__(
) -> None:
"""Construct a new async AsyncReplicateClient client instance.
- This automatically infers the `bearer_token` argument from the `REPLICATE_CLIENT_BEARER_TOKEN` environment variable if it is not provided.
+ This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided.
"""
if bearer_token is None:
- bearer_token = os.environ.get("REPLICATE_CLIENT_BEARER_TOKEN")
+ bearer_token = os.environ.get("REPLICATE_API_TOKEN")
if bearer_token is None:
raise ReplicateClientError(
- "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_CLIENT_BEARER_TOKEN environment variable"
+ "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable"
)
self.bearer_token = bearer_token
diff --git a/src/replicate/_models.py b/src/replicate/_models.py
index 3493571..798956f 100644
--- a/src/replicate/_models.py
+++ b/src/replicate/_models.py
@@ -19,7 +19,6 @@
)
import pydantic
-import pydantic.generics
from pydantic.fields import FieldInfo
from ._types import (
@@ -627,8 +626,8 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any,
# Note: if one variant defines an alias then they all should
discriminator_alias = field_info.alias
- if field_info.annotation and is_literal_type(field_info.annotation):
- for entry in get_args(field_info.annotation):
+ if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation):
+ for entry in get_args(annotation):
if isinstance(entry, str):
mapping[entry] = variant
diff --git a/src/replicate/_utils/_typing.py b/src/replicate/_utils/_typing.py
index 1958820..1bac954 100644
--- a/src/replicate/_utils/_typing.py
+++ b/src/replicate/_utils/_typing.py
@@ -110,7 +110,7 @@ class MyResponse(Foo[_T]):
```
"""
cls = cast(object, get_origin(typ) or typ)
- if cls in generic_bases:
+ if cls in generic_bases: # pyright: ignore[reportUnnecessaryContains]
# we're given the class directly
return extract_type_arg(typ, index)
diff --git a/src/replicate/_utils/_utils.py b/src/replicate/_utils/_utils.py
index e5811bb..ea3cf3f 100644
--- a/src/replicate/_utils/_utils.py
+++ b/src/replicate/_utils/_utils.py
@@ -72,8 +72,16 @@ def _extract_items(
from .._files import assert_is_file_content
# We have exhausted the path, return the entry we found.
- assert_is_file_content(obj, key=flattened_key)
assert flattened_key is not None
+
+ if is_list(obj):
+ files: list[tuple[str, FileTypes]] = []
+ for entry in obj:
+ assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "")
+ files.append((flattened_key + "[]", cast(FileTypes, entry)))
+ return files
+
+ assert_is_file_content(obj, key=flattened_key)
return [(flattened_key, cast(FileTypes, obj))]
index += 1
diff --git a/src/replicate/_version.py b/src/replicate/_version.py
index dfeb99c..5b5e6c4 100644
--- a/src/replicate/_version.py
+++ b/src/replicate/_version.py
@@ -1,4 +1,4 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
__title__ = "replicate"
-__version__ = "0.1.0-alpha.2" # x-release-please-version
+__version__ = "0.1.0-alpha.3" # x-release-please-version
diff --git a/src/replicate/resources/deployments/deployments.py b/src/replicate/resources/deployments/deployments.py
index 0211b67..565a1c6 100644
--- a/src/replicate/resources/deployments/deployments.py
+++ b/src/replicate/resources/deployments/deployments.py
@@ -6,10 +6,7 @@
from ...types import deployment_create_params, deployment_update_params
from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
-from ..._utils import (
- maybe_transform,
- async_maybe_transform,
-)
+from ..._utils import maybe_transform, async_maybe_transform
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
@@ -28,10 +25,10 @@
)
from ...pagination import SyncCursorURLPage, AsyncCursorURLPage
from ..._base_client import AsyncPaginator, make_request_options
+from ...types.deployment_get_response import DeploymentGetResponse
from ...types.deployment_list_response import DeploymentListResponse
from ...types.deployment_create_response import DeploymentCreateResponse
from ...types.deployment_update_response import DeploymentUpdateResponse
-from ...types.deployment_retrieve_response import DeploymentRetrieveResponse
__all__ = ["DeploymentsResource", "AsyncDeploymentsResource"]
@@ -165,77 +162,6 @@ def create(
cast_to=DeploymentCreateResponse,
)
- def retrieve(
- self,
- deployment_name: str,
- *,
- deployment_owner: str,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> DeploymentRetrieveResponse:
- """
- Get information about a deployment by name including the current release.
-
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/deployments/replicate/my-app-image-generator
- ```
-
- The response will be a JSON object describing the deployment:
-
- ```json
- {
- "owner": "acme",
- "name": "my-app-image-generator",
- "current_release": {
- "number": 1,
- "model": "stability-ai/sdxl",
- "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
- "created_at": "2024-02-15T16:32:57.018467Z",
- "created_by": {
- "type": "organization",
- "username": "acme",
- "name": "Acme Corp, Inc.",
- "avatar_url": "https://cdn.replicate.com/avatars/acme.png",
- "github_url": "https://github.com/acme"
- },
- "configuration": {
- "hardware": "gpu-t4",
- "min_instances": 1,
- "max_instances": 5
- }
- }
- }
- ```
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not deployment_owner:
- raise ValueError(f"Expected a non-empty value for `deployment_owner` but received {deployment_owner!r}")
- if not deployment_name:
- raise ValueError(f"Expected a non-empty value for `deployment_name` but received {deployment_name!r}")
- return self._get(
- f"/deployments/{deployment_owner}/{deployment_name}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=DeploymentRetrieveResponse,
- )
-
def update(
self,
deployment_name: str,
@@ -454,6 +380,77 @@ def delete(
cast_to=NoneType,
)
+ def get(
+ self,
+ deployment_name: str,
+ *,
+ deployment_owner: str,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> DeploymentGetResponse:
+ """
+ Get information about a deployment by name including the current release.
+
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/deployments/replicate/my-app-image-generator
+ ```
+
+ The response will be a JSON object describing the deployment:
+
+ ```json
+ {
+ "owner": "acme",
+ "name": "my-app-image-generator",
+ "current_release": {
+ "number": 1,
+ "model": "stability-ai/sdxl",
+ "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
+ "created_at": "2024-02-15T16:32:57.018467Z",
+ "created_by": {
+ "type": "organization",
+ "username": "acme",
+ "name": "Acme Corp, Inc.",
+ "avatar_url": "https://cdn.replicate.com/avatars/acme.png",
+ "github_url": "https://github.com/acme"
+ },
+ "configuration": {
+ "hardware": "gpu-t4",
+ "min_instances": 1,
+ "max_instances": 5
+ }
+ }
+ }
+ ```
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not deployment_owner:
+ raise ValueError(f"Expected a non-empty value for `deployment_owner` but received {deployment_owner!r}")
+ if not deployment_name:
+ raise ValueError(f"Expected a non-empty value for `deployment_name` but received {deployment_name!r}")
+ return self._get(
+ f"/deployments/{deployment_owner}/{deployment_name}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=DeploymentGetResponse,
+ )
+
def list_em_all(
self,
*,
@@ -628,77 +625,6 @@ async def create(
cast_to=DeploymentCreateResponse,
)
- async def retrieve(
- self,
- deployment_name: str,
- *,
- deployment_owner: str,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> DeploymentRetrieveResponse:
- """
- Get information about a deployment by name including the current release.
-
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/deployments/replicate/my-app-image-generator
- ```
-
- The response will be a JSON object describing the deployment:
-
- ```json
- {
- "owner": "acme",
- "name": "my-app-image-generator",
- "current_release": {
- "number": 1,
- "model": "stability-ai/sdxl",
- "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
- "created_at": "2024-02-15T16:32:57.018467Z",
- "created_by": {
- "type": "organization",
- "username": "acme",
- "name": "Acme Corp, Inc.",
- "avatar_url": "https://cdn.replicate.com/avatars/acme.png",
- "github_url": "https://github.com/acme"
- },
- "configuration": {
- "hardware": "gpu-t4",
- "min_instances": 1,
- "max_instances": 5
- }
- }
- }
- ```
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not deployment_owner:
- raise ValueError(f"Expected a non-empty value for `deployment_owner` but received {deployment_owner!r}")
- if not deployment_name:
- raise ValueError(f"Expected a non-empty value for `deployment_name` but received {deployment_name!r}")
- return await self._get(
- f"/deployments/{deployment_owner}/{deployment_name}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=DeploymentRetrieveResponse,
- )
-
async def update(
self,
deployment_name: str,
@@ -917,6 +843,77 @@ async def delete(
cast_to=NoneType,
)
+ async def get(
+ self,
+ deployment_name: str,
+ *,
+ deployment_owner: str,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> DeploymentGetResponse:
+ """
+ Get information about a deployment by name including the current release.
+
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/deployments/replicate/my-app-image-generator
+ ```
+
+ The response will be a JSON object describing the deployment:
+
+ ```json
+ {
+ "owner": "acme",
+ "name": "my-app-image-generator",
+ "current_release": {
+ "number": 1,
+ "model": "stability-ai/sdxl",
+ "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf",
+ "created_at": "2024-02-15T16:32:57.018467Z",
+ "created_by": {
+ "type": "organization",
+ "username": "acme",
+ "name": "Acme Corp, Inc.",
+ "avatar_url": "https://cdn.replicate.com/avatars/acme.png",
+ "github_url": "https://github.com/acme"
+ },
+ "configuration": {
+ "hardware": "gpu-t4",
+ "min_instances": 1,
+ "max_instances": 5
+ }
+ }
+ }
+ ```
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not deployment_owner:
+ raise ValueError(f"Expected a non-empty value for `deployment_owner` but received {deployment_owner!r}")
+ if not deployment_name:
+ raise ValueError(f"Expected a non-empty value for `deployment_name` but received {deployment_name!r}")
+ return await self._get(
+ f"/deployments/{deployment_owner}/{deployment_name}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=DeploymentGetResponse,
+ )
+
async def list_em_all(
self,
*,
@@ -969,9 +966,6 @@ def __init__(self, deployments: DeploymentsResource) -> None:
self.create = to_raw_response_wrapper(
deployments.create,
)
- self.retrieve = to_raw_response_wrapper(
- deployments.retrieve,
- )
self.update = to_raw_response_wrapper(
deployments.update,
)
@@ -981,6 +975,9 @@ def __init__(self, deployments: DeploymentsResource) -> None:
self.delete = to_raw_response_wrapper(
deployments.delete,
)
+ self.get = to_raw_response_wrapper(
+ deployments.get,
+ )
self.list_em_all = to_raw_response_wrapper(
deployments.list_em_all,
)
@@ -997,9 +994,6 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None:
self.create = async_to_raw_response_wrapper(
deployments.create,
)
- self.retrieve = async_to_raw_response_wrapper(
- deployments.retrieve,
- )
self.update = async_to_raw_response_wrapper(
deployments.update,
)
@@ -1009,6 +1003,9 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None:
self.delete = async_to_raw_response_wrapper(
deployments.delete,
)
+ self.get = async_to_raw_response_wrapper(
+ deployments.get,
+ )
self.list_em_all = async_to_raw_response_wrapper(
deployments.list_em_all,
)
@@ -1025,9 +1022,6 @@ def __init__(self, deployments: DeploymentsResource) -> None:
self.create = to_streamed_response_wrapper(
deployments.create,
)
- self.retrieve = to_streamed_response_wrapper(
- deployments.retrieve,
- )
self.update = to_streamed_response_wrapper(
deployments.update,
)
@@ -1037,6 +1031,9 @@ def __init__(self, deployments: DeploymentsResource) -> None:
self.delete = to_streamed_response_wrapper(
deployments.delete,
)
+ self.get = to_streamed_response_wrapper(
+ deployments.get,
+ )
self.list_em_all = to_streamed_response_wrapper(
deployments.list_em_all,
)
@@ -1053,9 +1050,6 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None:
self.create = async_to_streamed_response_wrapper(
deployments.create,
)
- self.retrieve = async_to_streamed_response_wrapper(
- deployments.retrieve,
- )
self.update = async_to_streamed_response_wrapper(
deployments.update,
)
@@ -1065,6 +1059,9 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None:
self.delete = async_to_streamed_response_wrapper(
deployments.delete,
)
+ self.get = async_to_streamed_response_wrapper(
+ deployments.get,
+ )
self.list_em_all = async_to_streamed_response_wrapper(
deployments.list_em_all,
)
diff --git a/src/replicate/resources/deployments/predictions.py b/src/replicate/resources/deployments/predictions.py
index 7177332..c94fff0 100644
--- a/src/replicate/resources/deployments/predictions.py
+++ b/src/replicate/resources/deployments/predictions.py
@@ -8,11 +8,7 @@
import httpx
from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven
-from ..._utils import (
- maybe_transform,
- strip_not_given,
- async_maybe_transform,
-)
+from ..._utils import maybe_transform, strip_not_given, async_maybe_transform
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
diff --git a/src/replicate/resources/models/models.py b/src/replicate/resources/models/models.py
index 77145ae..435101f 100644
--- a/src/replicate/resources/models/models.py
+++ b/src/replicate/resources/models/models.py
@@ -9,11 +9,7 @@
from ...types import model_create_params, model_create_prediction_params
from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
-from ..._utils import (
- maybe_transform,
- strip_not_given,
- async_maybe_transform,
-)
+from ..._utils import maybe_transform, strip_not_given, async_maybe_transform
from .versions import (
VersionsResource,
AsyncVersionsResource,
@@ -174,115 +170,6 @@ def create(
cast_to=NoneType,
)
- def retrieve(
- self,
- model_name: str,
- *,
- model_owner: str,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
- """
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/models/replicate/hello-world
- ```
-
- The response will be a model object in the following format:
-
- ```json
- {
- "url": "https://replicate.com/replicate/hello-world",
- "owner": "replicate",
- "name": "hello-world",
- "description": "A tiny model that says hello",
- "visibility": "public",
- "github_url": "https://github.com/replicate/cog-examples",
- "paper_url": null,
- "license_url": null,
- "run_count": 5681081,
- "cover_image_url": "...",
- "default_example": {...},
- "latest_version": {...},
- }
- ```
-
- The model object includes the
- [input and output schema](https://replicate.com/docs/reference/openapi#model-schemas)
- for the latest version of the model.
-
- Here's an example showing how to fetch the model with cURL and display its input
- schema with [jq](https://stedolan.github.io/jq/):
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/models/replicate/hello-world \\
- | jq ".latest_version.openapi_schema.components.schemas.Input"
- ```
-
- This will return the following JSON object:
-
- ```json
- {
- "type": "object",
- "title": "Input",
- "required": ["text"],
- "properties": {
- "text": {
- "type": "string",
- "title": "Text",
- "x-order": 0,
- "description": "Text to prefix with 'hello '"
- }
- }
- }
- ```
-
- The `cover_image_url` string is an HTTPS URL for an image file. This can be:
-
- - An image uploaded by the model author.
- - The output file of the example prediction, if the model author has not set a
- cover image.
- - The input file of the example prediction, if the model author has not set a
- cover image and the example prediction has no output file.
- - A generic fallback image.
-
- The `default_example` object is a [prediction](#predictions.get) created with
- this model.
-
- The `latest_version` object is the model's most recently pushed
- [version](#models.versions.get).
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not model_owner:
- raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
- if not model_name:
- raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return self._get(
- f"/models/{model_owner}/{model_name}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=NoneType,
- )
-
def list(
self,
*,
@@ -514,6 +401,115 @@ def create_prediction(
cast_to=Prediction,
)
+ def get(
+ self,
+ model_name: str,
+ *,
+ model_owner: str,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> None:
+ """
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/models/replicate/hello-world
+ ```
+
+ The response will be a model object in the following format:
+
+ ```json
+ {
+ "url": "https://replicate.com/replicate/hello-world",
+ "owner": "replicate",
+ "name": "hello-world",
+ "description": "A tiny model that says hello",
+ "visibility": "public",
+ "github_url": "https://github.com/replicate/cog-examples",
+ "paper_url": null,
+ "license_url": null,
+ "run_count": 5681081,
+ "cover_image_url": "...",
+ "default_example": {...},
+ "latest_version": {...},
+ }
+ ```
+
+ The model object includes the
+ [input and output schema](https://replicate.com/docs/reference/openapi#model-schemas)
+ for the latest version of the model.
+
+ Here's an example showing how to fetch the model with cURL and display its input
+ schema with [jq](https://stedolan.github.io/jq/):
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/models/replicate/hello-world \\
+ | jq ".latest_version.openapi_schema.components.schemas.Input"
+ ```
+
+ This will return the following JSON object:
+
+ ```json
+ {
+ "type": "object",
+ "title": "Input",
+ "required": ["text"],
+ "properties": {
+ "text": {
+ "type": "string",
+ "title": "Text",
+ "x-order": 0,
+ "description": "Text to prefix with 'hello '"
+ }
+ }
+ }
+ ```
+
+ The `cover_image_url` string is an HTTPS URL for an image file. This can be:
+
+ - An image uploaded by the model author.
+ - The output file of the example prediction, if the model author has not set a
+ cover image.
+ - The input file of the example prediction, if the model author has not set a
+ cover image and the example prediction has no output file.
+ - A generic fallback image.
+
+ The `default_example` object is a [prediction](#predictions.get) created with
+ this model.
+
+ The `latest_version` object is the model's most recently pushed
+ [version](#models.versions.get).
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not model_owner:
+ raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
+ if not model_name:
+ raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
+ extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ return self._get(
+ f"/models/{model_owner}/{model_name}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=NoneType,
+ )
+
class AsyncModelsResource(AsyncAPIResource):
@cached_property
@@ -651,115 +647,6 @@ async def create(
cast_to=NoneType,
)
- async def retrieve(
- self,
- model_name: str,
- *,
- model_owner: str,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
- """
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/models/replicate/hello-world
- ```
-
- The response will be a model object in the following format:
-
- ```json
- {
- "url": "https://replicate.com/replicate/hello-world",
- "owner": "replicate",
- "name": "hello-world",
- "description": "A tiny model that says hello",
- "visibility": "public",
- "github_url": "https://github.com/replicate/cog-examples",
- "paper_url": null,
- "license_url": null,
- "run_count": 5681081,
- "cover_image_url": "...",
- "default_example": {...},
- "latest_version": {...},
- }
- ```
-
- The model object includes the
- [input and output schema](https://replicate.com/docs/reference/openapi#model-schemas)
- for the latest version of the model.
-
- Here's an example showing how to fetch the model with cURL and display its input
- schema with [jq](https://stedolan.github.io/jq/):
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/models/replicate/hello-world \\
- | jq ".latest_version.openapi_schema.components.schemas.Input"
- ```
-
- This will return the following JSON object:
-
- ```json
- {
- "type": "object",
- "title": "Input",
- "required": ["text"],
- "properties": {
- "text": {
- "type": "string",
- "title": "Text",
- "x-order": 0,
- "description": "Text to prefix with 'hello '"
- }
- }
- }
- ```
-
- The `cover_image_url` string is an HTTPS URL for an image file. This can be:
-
- - An image uploaded by the model author.
- - The output file of the example prediction, if the model author has not set a
- cover image.
- - The input file of the example prediction, if the model author has not set a
- cover image and the example prediction has no output file.
- - A generic fallback image.
-
- The `default_example` object is a [prediction](#predictions.get) created with
- this model.
-
- The `latest_version` object is the model's most recently pushed
- [version](#models.versions.get).
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not model_owner:
- raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
- if not model_name:
- raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return await self._get(
- f"/models/{model_owner}/{model_name}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=NoneType,
- )
-
def list(
self,
*,
@@ -991,6 +878,115 @@ async def create_prediction(
cast_to=Prediction,
)
+ async def get(
+ self,
+ model_name: str,
+ *,
+ model_owner: str,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> None:
+ """
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/models/replicate/hello-world
+ ```
+
+ The response will be a model object in the following format:
+
+ ```json
+ {
+ "url": "https://replicate.com/replicate/hello-world",
+ "owner": "replicate",
+ "name": "hello-world",
+ "description": "A tiny model that says hello",
+ "visibility": "public",
+ "github_url": "https://github.com/replicate/cog-examples",
+ "paper_url": null,
+ "license_url": null,
+ "run_count": 5681081,
+ "cover_image_url": "...",
+ "default_example": {...},
+ "latest_version": {...},
+ }
+ ```
+
+ The model object includes the
+ [input and output schema](https://replicate.com/docs/reference/openapi#model-schemas)
+ for the latest version of the model.
+
+ Here's an example showing how to fetch the model with cURL and display its input
+ schema with [jq](https://stedolan.github.io/jq/):
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/models/replicate/hello-world \\
+ | jq ".latest_version.openapi_schema.components.schemas.Input"
+ ```
+
+ This will return the following JSON object:
+
+ ```json
+ {
+ "type": "object",
+ "title": "Input",
+ "required": ["text"],
+ "properties": {
+ "text": {
+ "type": "string",
+ "title": "Text",
+ "x-order": 0,
+ "description": "Text to prefix with 'hello '"
+ }
+ }
+ }
+ ```
+
+ The `cover_image_url` string is an HTTPS URL for an image file. This can be:
+
+ - An image uploaded by the model author.
+ - The output file of the example prediction, if the model author has not set a
+ cover image.
+ - The input file of the example prediction, if the model author has not set a
+ cover image and the example prediction has no output file.
+ - A generic fallback image.
+
+ The `default_example` object is a [prediction](#predictions.get) created with
+ this model.
+
+ The `latest_version` object is the model's most recently pushed
+ [version](#models.versions.get).
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not model_owner:
+ raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
+ if not model_name:
+ raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
+ extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ return await self._get(
+ f"/models/{model_owner}/{model_name}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=NoneType,
+ )
+
class ModelsResourceWithRawResponse:
def __init__(self, models: ModelsResource) -> None:
@@ -999,9 +995,6 @@ def __init__(self, models: ModelsResource) -> None:
self.create = to_raw_response_wrapper(
models.create,
)
- self.retrieve = to_raw_response_wrapper(
- models.retrieve,
- )
self.list = to_raw_response_wrapper(
models.list,
)
@@ -1011,6 +1004,9 @@ def __init__(self, models: ModelsResource) -> None:
self.create_prediction = to_raw_response_wrapper(
models.create_prediction,
)
+ self.get = to_raw_response_wrapper(
+ models.get,
+ )
@cached_property
def versions(self) -> VersionsResourceWithRawResponse:
@@ -1024,9 +1020,6 @@ def __init__(self, models: AsyncModelsResource) -> None:
self.create = async_to_raw_response_wrapper(
models.create,
)
- self.retrieve = async_to_raw_response_wrapper(
- models.retrieve,
- )
self.list = async_to_raw_response_wrapper(
models.list,
)
@@ -1036,6 +1029,9 @@ def __init__(self, models: AsyncModelsResource) -> None:
self.create_prediction = async_to_raw_response_wrapper(
models.create_prediction,
)
+ self.get = async_to_raw_response_wrapper(
+ models.get,
+ )
@cached_property
def versions(self) -> AsyncVersionsResourceWithRawResponse:
@@ -1049,9 +1045,6 @@ def __init__(self, models: ModelsResource) -> None:
self.create = to_streamed_response_wrapper(
models.create,
)
- self.retrieve = to_streamed_response_wrapper(
- models.retrieve,
- )
self.list = to_streamed_response_wrapper(
models.list,
)
@@ -1061,6 +1054,9 @@ def __init__(self, models: ModelsResource) -> None:
self.create_prediction = to_streamed_response_wrapper(
models.create_prediction,
)
+ self.get = to_streamed_response_wrapper(
+ models.get,
+ )
@cached_property
def versions(self) -> VersionsResourceWithStreamingResponse:
@@ -1074,9 +1070,6 @@ def __init__(self, models: AsyncModelsResource) -> None:
self.create = async_to_streamed_response_wrapper(
models.create,
)
- self.retrieve = async_to_streamed_response_wrapper(
- models.retrieve,
- )
self.list = async_to_streamed_response_wrapper(
models.list,
)
@@ -1086,6 +1079,9 @@ def __init__(self, models: AsyncModelsResource) -> None:
self.create_prediction = async_to_streamed_response_wrapper(
models.create_prediction,
)
+ self.get = async_to_streamed_response_wrapper(
+ models.get,
+ )
@cached_property
def versions(self) -> AsyncVersionsResourceWithStreamingResponse:
diff --git a/src/replicate/resources/models/versions.py b/src/replicate/resources/models/versions.py
index eb411ef..308d04b 100644
--- a/src/replicate/resources/models/versions.py
+++ b/src/replicate/resources/models/versions.py
@@ -8,10 +8,7 @@
import httpx
from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
-from ..._utils import (
- maybe_transform,
- async_maybe_transform,
-)
+from ..._utils import maybe_transform, async_maybe_transform
from ..._compat import cached_property
from ..._resource import SyncAPIResource, AsyncAPIResource
from ..._response import (
@@ -22,6 +19,7 @@
)
from ..._base_client import make_request_options
from ...types.models import version_create_training_params
+from ...types.models.version_create_training_response import VersionCreateTrainingResponse
__all__ = ["VersionsResource", "AsyncVersionsResource"]
@@ -46,101 +44,6 @@ def with_streaming_response(self) -> VersionsResourceWithStreamingResponse:
"""
return VersionsResourceWithStreamingResponse(self)
- def retrieve(
- self,
- version_id: str,
- *,
- model_owner: str,
- model_name: str,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
- """
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/models/replicate/hello-world/versions/5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa
- ```
-
- The response will be the version object:
-
- ```json
- {
- "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
- "created_at": "2022-04-26T19:29:04.418669Z",
- "cog_version": "0.3.0",
- "openapi_schema": {...}
- }
- ```
-
- Every model describes its inputs and outputs with
- [OpenAPI Schema Objects](https://spec.openapis.org/oas/latest.html#schemaObject)
- in the `openapi_schema` property.
-
- The `openapi_schema.components.schemas.Input` property for the
- [replicate/hello-world](https://replicate.com/replicate/hello-world) model looks
- like this:
-
- ```json
- {
- "type": "object",
- "title": "Input",
- "required": ["text"],
- "properties": {
- "text": {
- "x-order": 0,
- "type": "string",
- "title": "Text",
- "description": "Text to prefix with 'hello '"
- }
- }
- }
- ```
-
- The `openapi_schema.components.schemas.Output` property for the
- [replicate/hello-world](https://replicate.com/replicate/hello-world) model looks
- like this:
-
- ```json
- {
- "type": "string",
- "title": "Output"
- }
- ```
-
- For more details, see the docs on
- [Cog's supported input and output types](https://github.com/replicate/cog/blob/75b7802219e7cd4cee845e34c4c22139558615d4/docs/python.md#input-and-output-types)
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not model_owner:
- raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
- if not model_name:
- raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
- if not version_id:
- raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return self._get(
- f"/models/{model_owner}/{model_name}/versions/{version_id}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=NoneType,
- )
-
def list(
self,
model_name: str,
@@ -281,7 +184,7 @@ def create_training(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
+ ) -> VersionCreateTrainingResponse:
"""
Start a new training of the model version you specify.
@@ -398,7 +301,6 @@ def create_training(
raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
if not version_id:
raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
return self._post(
f"/models/{model_owner}/{model_name}/versions/{version_id}/trainings",
body=maybe_transform(
@@ -413,31 +315,10 @@ def create_training(
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=NoneType,
+ cast_to=VersionCreateTrainingResponse,
)
-
-class AsyncVersionsResource(AsyncAPIResource):
- @cached_property
- def with_raw_response(self) -> AsyncVersionsResourceWithRawResponse:
- """
- This property can be used as a prefix for any HTTP method call to return
- the raw response object instead of the parsed content.
-
- For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers
- """
- return AsyncVersionsResourceWithRawResponse(self)
-
- @cached_property
- def with_streaming_response(self) -> AsyncVersionsResourceWithStreamingResponse:
- """
- An alternative to `.with_raw_response` that doesn't eagerly read the response body.
-
- For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response
- """
- return AsyncVersionsResourceWithStreamingResponse(self)
-
- async def retrieve(
+ def get(
self,
version_id: str,
*,
@@ -524,7 +405,7 @@ async def retrieve(
if not version_id:
raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}")
extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return await self._get(
+ return self._get(
f"/models/{model_owner}/{model_name}/versions/{version_id}",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
@@ -532,6 +413,27 @@ async def retrieve(
cast_to=NoneType,
)
+
+class AsyncVersionsResource(AsyncAPIResource):
+ @cached_property
+ def with_raw_response(self) -> AsyncVersionsResourceWithRawResponse:
+ """
+ This property can be used as a prefix for any HTTP method call to return
+ the raw response object instead of the parsed content.
+
+ For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers
+ """
+ return AsyncVersionsResourceWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> AsyncVersionsResourceWithStreamingResponse:
+ """
+ An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+ For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response
+ """
+ return AsyncVersionsResourceWithStreamingResponse(self)
+
async def list(
self,
model_name: str,
@@ -672,7 +574,7 @@ async def create_training(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
+ ) -> VersionCreateTrainingResponse:
"""
Start a new training of the model version you specify.
@@ -789,7 +691,6 @@ async def create_training(
raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
if not version_id:
raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
return await self._post(
f"/models/{model_owner}/{model_name}/versions/{version_id}/trainings",
body=await async_maybe_transform(
@@ -804,6 +705,101 @@ async def create_training(
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
+ cast_to=VersionCreateTrainingResponse,
+ )
+
+ async def get(
+ self,
+ version_id: str,
+ *,
+ model_owner: str,
+ model_name: str,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> None:
+ """
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/models/replicate/hello-world/versions/5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa
+ ```
+
+ The response will be the version object:
+
+ ```json
+ {
+ "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
+ "created_at": "2022-04-26T19:29:04.418669Z",
+ "cog_version": "0.3.0",
+ "openapi_schema": {...}
+ }
+ ```
+
+ Every model describes its inputs and outputs with
+ [OpenAPI Schema Objects](https://spec.openapis.org/oas/latest.html#schemaObject)
+ in the `openapi_schema` property.
+
+ The `openapi_schema.components.schemas.Input` property for the
+ [replicate/hello-world](https://replicate.com/replicate/hello-world) model looks
+ like this:
+
+ ```json
+ {
+ "type": "object",
+ "title": "Input",
+ "required": ["text"],
+ "properties": {
+ "text": {
+ "x-order": 0,
+ "type": "string",
+ "title": "Text",
+ "description": "Text to prefix with 'hello '"
+ }
+ }
+ }
+ ```
+
+ The `openapi_schema.components.schemas.Output` property for the
+ [replicate/hello-world](https://replicate.com/replicate/hello-world) model looks
+ like this:
+
+ ```json
+ {
+ "type": "string",
+ "title": "Output"
+ }
+ ```
+
+ For more details, see the docs on
+ [Cog's supported input and output types](https://github.com/replicate/cog/blob/75b7802219e7cd4cee845e34c4c22139558615d4/docs/python.md#input-and-output-types)
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not model_owner:
+ raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
+ if not model_name:
+ raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}")
+ if not version_id:
+ raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}")
+ extra_headers = {"Accept": "*/*", **(extra_headers or {})}
+ return await self._get(
+ f"/models/{model_owner}/{model_name}/versions/{version_id}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
cast_to=NoneType,
)
@@ -812,9 +808,6 @@ class VersionsResourceWithRawResponse:
def __init__(self, versions: VersionsResource) -> None:
self._versions = versions
- self.retrieve = to_raw_response_wrapper(
- versions.retrieve,
- )
self.list = to_raw_response_wrapper(
versions.list,
)
@@ -824,15 +817,15 @@ def __init__(self, versions: VersionsResource) -> None:
self.create_training = to_raw_response_wrapper(
versions.create_training,
)
+ self.get = to_raw_response_wrapper(
+ versions.get,
+ )
class AsyncVersionsResourceWithRawResponse:
def __init__(self, versions: AsyncVersionsResource) -> None:
self._versions = versions
- self.retrieve = async_to_raw_response_wrapper(
- versions.retrieve,
- )
self.list = async_to_raw_response_wrapper(
versions.list,
)
@@ -842,15 +835,15 @@ def __init__(self, versions: AsyncVersionsResource) -> None:
self.create_training = async_to_raw_response_wrapper(
versions.create_training,
)
+ self.get = async_to_raw_response_wrapper(
+ versions.get,
+ )
class VersionsResourceWithStreamingResponse:
def __init__(self, versions: VersionsResource) -> None:
self._versions = versions
- self.retrieve = to_streamed_response_wrapper(
- versions.retrieve,
- )
self.list = to_streamed_response_wrapper(
versions.list,
)
@@ -860,15 +853,15 @@ def __init__(self, versions: VersionsResource) -> None:
self.create_training = to_streamed_response_wrapper(
versions.create_training,
)
+ self.get = to_streamed_response_wrapper(
+ versions.get,
+ )
class AsyncVersionsResourceWithStreamingResponse:
def __init__(self, versions: AsyncVersionsResource) -> None:
self._versions = versions
- self.retrieve = async_to_streamed_response_wrapper(
- versions.retrieve,
- )
self.list = async_to_streamed_response_wrapper(
versions.list,
)
@@ -878,3 +871,6 @@ def __init__(self, versions: AsyncVersionsResource) -> None:
self.create_training = async_to_streamed_response_wrapper(
versions.create_training,
)
+ self.get = async_to_streamed_response_wrapper(
+ versions.get,
+ )
diff --git a/src/replicate/resources/predictions.py b/src/replicate/resources/predictions.py
index d6a1245..abe0ca1 100644
--- a/src/replicate/resources/predictions.py
+++ b/src/replicate/resources/predictions.py
@@ -10,11 +10,7 @@
from ..types import prediction_list_params, prediction_create_params
from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
-from .._utils import (
- maybe_transform,
- strip_not_given,
- async_maybe_transform,
-)
+from .._utils import maybe_transform, strip_not_given, async_maybe_transform
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
from .._response import (
@@ -189,108 +185,6 @@ def create(
cast_to=Prediction,
)
- def retrieve(
- self,
- prediction_id: str,
- *,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> Prediction:
- """
- Get the current state of a prediction.
-
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu
- ```
-
- The response will be the prediction object:
-
- ```json
- {
- "id": "gm3qorzdhgbfurvjtvhg6dckhu",
- "model": "replicate/hello-world",
- "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
- "input": {
- "text": "Alice"
- },
- "logs": "",
- "output": "hello Alice",
- "error": null,
- "status": "succeeded",
- "created_at": "2023-09-08T16:19:34.765994Z",
- "data_removed": false,
- "started_at": "2023-09-08T16:19:34.779176Z",
- "completed_at": "2023-09-08T16:19:34.791859Z",
- "metrics": {
- "predict_time": 0.012683
- },
- "urls": {
- "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel",
- "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu"
- }
- }
- ```
-
- `status` will be one of:
-
- - `starting`: the prediction is starting up. If this status lasts longer than a
- few seconds, then it's typically because a new worker is being started to run
- the prediction.
- - `processing`: the `predict()` method of the model is currently running.
- - `succeeded`: the prediction completed successfully.
- - `failed`: the prediction encountered an error during processing.
- - `canceled`: the prediction was canceled by its creator.
-
- In the case of success, `output` will be an object containing the output of the
- model. Any files will be represented as HTTPS URLs. You'll need to pass the
- `Authorization` header to request them.
-
- In the case of failure, `error` will contain the error encountered during the
- prediction.
-
- Terminated predictions (with a status of `succeeded`, `failed`, or `canceled`)
- will include a `metrics` object with a `predict_time` property showing the
- amount of CPU or GPU time, in seconds, that the prediction used while running.
- It won't include time waiting for the prediction to start.
-
- All input parameters, output values, and logs are automatically removed after an
- hour, by default, for predictions created through the API.
-
- You must save a copy of any data or files in the output if you'd like to
- continue using them. The `output` key will still be present, but it's value will
- be `null` after the output has been removed.
-
- Output files are served by `replicate.delivery` and its subdomains. If you use
- an allow list of external domains for your assets, add `replicate.delivery` and
- `*.replicate.delivery` to it.
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not prediction_id:
- raise ValueError(f"Expected a non-empty value for `prediction_id` but received {prediction_id!r}")
- return self._get(
- f"/predictions/{prediction_id}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=Prediction,
- )
-
def list(
self,
*,
@@ -440,6 +334,108 @@ def cancel(
cast_to=NoneType,
)
+ def get(
+ self,
+ prediction_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> Prediction:
+ """
+ Get the current state of a prediction.
+
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu
+ ```
+
+ The response will be the prediction object:
+
+ ```json
+ {
+ "id": "gm3qorzdhgbfurvjtvhg6dckhu",
+ "model": "replicate/hello-world",
+ "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
+ "input": {
+ "text": "Alice"
+ },
+ "logs": "",
+ "output": "hello Alice",
+ "error": null,
+ "status": "succeeded",
+ "created_at": "2023-09-08T16:19:34.765994Z",
+ "data_removed": false,
+ "started_at": "2023-09-08T16:19:34.779176Z",
+ "completed_at": "2023-09-08T16:19:34.791859Z",
+ "metrics": {
+ "predict_time": 0.012683
+ },
+ "urls": {
+ "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel",
+ "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu"
+ }
+ }
+ ```
+
+ `status` will be one of:
+
+ - `starting`: the prediction is starting up. If this status lasts longer than a
+ few seconds, then it's typically because a new worker is being started to run
+ the prediction.
+ - `processing`: the `predict()` method of the model is currently running.
+ - `succeeded`: the prediction completed successfully.
+ - `failed`: the prediction encountered an error during processing.
+ - `canceled`: the prediction was canceled by its creator.
+
+ In the case of success, `output` will be an object containing the output of the
+ model. Any files will be represented as HTTPS URLs. You'll need to pass the
+ `Authorization` header to request them.
+
+ In the case of failure, `error` will contain the error encountered during the
+ prediction.
+
+ Terminated predictions (with a status of `succeeded`, `failed`, or `canceled`)
+ will include a `metrics` object with a `predict_time` property showing the
+ amount of CPU or GPU time, in seconds, that the prediction used while running.
+ It won't include time waiting for the prediction to start.
+
+ All input parameters, output values, and logs are automatically removed after an
+ hour, by default, for predictions created through the API.
+
+ You must save a copy of any data or files in the output if you'd like to
+ continue using them. The `output` key will still be present, but it's value will
+ be `null` after the output has been removed.
+
+ Output files are served by `replicate.delivery` and its subdomains. If you use
+ an allow list of external domains for your assets, add `replicate.delivery` and
+ `*.replicate.delivery` to it.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not prediction_id:
+ raise ValueError(f"Expected a non-empty value for `prediction_id` but received {prediction_id!r}")
+ return self._get(
+ f"/predictions/{prediction_id}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=Prediction,
+ )
+
class AsyncPredictionsResource(AsyncAPIResource):
@cached_property
@@ -600,108 +596,6 @@ async def create(
cast_to=Prediction,
)
- async def retrieve(
- self,
- prediction_id: str,
- *,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> Prediction:
- """
- Get the current state of a prediction.
-
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu
- ```
-
- The response will be the prediction object:
-
- ```json
- {
- "id": "gm3qorzdhgbfurvjtvhg6dckhu",
- "model": "replicate/hello-world",
- "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
- "input": {
- "text": "Alice"
- },
- "logs": "",
- "output": "hello Alice",
- "error": null,
- "status": "succeeded",
- "created_at": "2023-09-08T16:19:34.765994Z",
- "data_removed": false,
- "started_at": "2023-09-08T16:19:34.779176Z",
- "completed_at": "2023-09-08T16:19:34.791859Z",
- "metrics": {
- "predict_time": 0.012683
- },
- "urls": {
- "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel",
- "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu"
- }
- }
- ```
-
- `status` will be one of:
-
- - `starting`: the prediction is starting up. If this status lasts longer than a
- few seconds, then it's typically because a new worker is being started to run
- the prediction.
- - `processing`: the `predict()` method of the model is currently running.
- - `succeeded`: the prediction completed successfully.
- - `failed`: the prediction encountered an error during processing.
- - `canceled`: the prediction was canceled by its creator.
-
- In the case of success, `output` will be an object containing the output of the
- model. Any files will be represented as HTTPS URLs. You'll need to pass the
- `Authorization` header to request them.
-
- In the case of failure, `error` will contain the error encountered during the
- prediction.
-
- Terminated predictions (with a status of `succeeded`, `failed`, or `canceled`)
- will include a `metrics` object with a `predict_time` property showing the
- amount of CPU or GPU time, in seconds, that the prediction used while running.
- It won't include time waiting for the prediction to start.
-
- All input parameters, output values, and logs are automatically removed after an
- hour, by default, for predictions created through the API.
-
- You must save a copy of any data or files in the output if you'd like to
- continue using them. The `output` key will still be present, but it's value will
- be `null` after the output has been removed.
-
- Output files are served by `replicate.delivery` and its subdomains. If you use
- an allow list of external domains for your assets, add `replicate.delivery` and
- `*.replicate.delivery` to it.
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not prediction_id:
- raise ValueError(f"Expected a non-empty value for `prediction_id` but received {prediction_id!r}")
- return await self._get(
- f"/predictions/{prediction_id}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=Prediction,
- )
-
def list(
self,
*,
@@ -851,6 +745,108 @@ async def cancel(
cast_to=NoneType,
)
+ async def get(
+ self,
+ prediction_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> Prediction:
+ """
+ Get the current state of a prediction.
+
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu
+ ```
+
+ The response will be the prediction object:
+
+ ```json
+ {
+ "id": "gm3qorzdhgbfurvjtvhg6dckhu",
+ "model": "replicate/hello-world",
+ "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
+ "input": {
+ "text": "Alice"
+ },
+ "logs": "",
+ "output": "hello Alice",
+ "error": null,
+ "status": "succeeded",
+ "created_at": "2023-09-08T16:19:34.765994Z",
+ "data_removed": false,
+ "started_at": "2023-09-08T16:19:34.779176Z",
+ "completed_at": "2023-09-08T16:19:34.791859Z",
+ "metrics": {
+ "predict_time": 0.012683
+ },
+ "urls": {
+ "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel",
+ "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu"
+ }
+ }
+ ```
+
+ `status` will be one of:
+
+ - `starting`: the prediction is starting up. If this status lasts longer than a
+ few seconds, then it's typically because a new worker is being started to run
+ the prediction.
+ - `processing`: the `predict()` method of the model is currently running.
+ - `succeeded`: the prediction completed successfully.
+ - `failed`: the prediction encountered an error during processing.
+ - `canceled`: the prediction was canceled by its creator.
+
+ In the case of success, `output` will be an object containing the output of the
+ model. Any files will be represented as HTTPS URLs. You'll need to pass the
+ `Authorization` header to request them.
+
+ In the case of failure, `error` will contain the error encountered during the
+ prediction.
+
+ Terminated predictions (with a status of `succeeded`, `failed`, or `canceled`)
+ will include a `metrics` object with a `predict_time` property showing the
+ amount of CPU or GPU time, in seconds, that the prediction used while running.
+ It won't include time waiting for the prediction to start.
+
+ All input parameters, output values, and logs are automatically removed after an
+ hour, by default, for predictions created through the API.
+
+ You must save a copy of any data or files in the output if you'd like to
+ continue using them. The `output` key will still be present, but it's value will
+ be `null` after the output has been removed.
+
+ Output files are served by `replicate.delivery` and its subdomains. If you use
+ an allow list of external domains for your assets, add `replicate.delivery` and
+ `*.replicate.delivery` to it.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not prediction_id:
+ raise ValueError(f"Expected a non-empty value for `prediction_id` but received {prediction_id!r}")
+ return await self._get(
+ f"/predictions/{prediction_id}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=Prediction,
+ )
+
class PredictionsResourceWithRawResponse:
def __init__(self, predictions: PredictionsResource) -> None:
@@ -859,15 +855,15 @@ def __init__(self, predictions: PredictionsResource) -> None:
self.create = to_raw_response_wrapper(
predictions.create,
)
- self.retrieve = to_raw_response_wrapper(
- predictions.retrieve,
- )
self.list = to_raw_response_wrapper(
predictions.list,
)
self.cancel = to_raw_response_wrapper(
predictions.cancel,
)
+ self.get = to_raw_response_wrapper(
+ predictions.get,
+ )
class AsyncPredictionsResourceWithRawResponse:
@@ -877,15 +873,15 @@ def __init__(self, predictions: AsyncPredictionsResource) -> None:
self.create = async_to_raw_response_wrapper(
predictions.create,
)
- self.retrieve = async_to_raw_response_wrapper(
- predictions.retrieve,
- )
self.list = async_to_raw_response_wrapper(
predictions.list,
)
self.cancel = async_to_raw_response_wrapper(
predictions.cancel,
)
+ self.get = async_to_raw_response_wrapper(
+ predictions.get,
+ )
class PredictionsResourceWithStreamingResponse:
@@ -895,15 +891,15 @@ def __init__(self, predictions: PredictionsResource) -> None:
self.create = to_streamed_response_wrapper(
predictions.create,
)
- self.retrieve = to_streamed_response_wrapper(
- predictions.retrieve,
- )
self.list = to_streamed_response_wrapper(
predictions.list,
)
self.cancel = to_streamed_response_wrapper(
predictions.cancel,
)
+ self.get = to_streamed_response_wrapper(
+ predictions.get,
+ )
class AsyncPredictionsResourceWithStreamingResponse:
@@ -913,12 +909,12 @@ def __init__(self, predictions: AsyncPredictionsResource) -> None:
self.create = async_to_streamed_response_wrapper(
predictions.create,
)
- self.retrieve = async_to_streamed_response_wrapper(
- predictions.retrieve,
- )
self.list = async_to_streamed_response_wrapper(
predictions.list,
)
self.cancel = async_to_streamed_response_wrapper(
predictions.cancel,
)
+ self.get = async_to_streamed_response_wrapper(
+ predictions.get,
+ )
diff --git a/src/replicate/resources/trainings.py b/src/replicate/resources/trainings.py
index df64c19..5a357af 100644
--- a/src/replicate/resources/trainings.py
+++ b/src/replicate/resources/trainings.py
@@ -4,7 +4,7 @@
import httpx
-from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven
+from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._compat import cached_property
from .._resource import SyncAPIResource, AsyncAPIResource
from .._response import (
@@ -13,7 +13,11 @@
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
-from .._base_client import make_request_options
+from ..pagination import SyncCursorURLPage, AsyncCursorURLPage
+from .._base_client import AsyncPaginator, make_request_options
+from ..types.training_get_response import TrainingGetResponse
+from ..types.training_list_response import TrainingListResponse
+from ..types.training_cancel_response import TrainingCancelResponse
__all__ = ["TrainingsResource", "AsyncTrainingsResource"]
@@ -38,100 +42,6 @@ def with_streaming_response(self) -> TrainingsResourceWithStreamingResponse:
"""
return TrainingsResourceWithStreamingResponse(self)
- def retrieve(
- self,
- training_id: str,
- *,
- # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
- # The extra values given here take precedence over values defined on the client or passed to this method.
- extra_headers: Headers | None = None,
- extra_query: Query | None = None,
- extra_body: Body | None = None,
- timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
- """
- Get the current state of a training.
-
- Example cURL request:
-
- ```console
- curl -s \\
- -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
- https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga
- ```
-
- The response will be the training object:
-
- ```json
- {
- "completed_at": "2023-09-08T16:41:19.826523Z",
- "created_at": "2023-09-08T16:32:57.018467Z",
- "error": null,
- "id": "zz4ibbonubfz7carwiefibzgga",
- "input": {
- "input_images": "https://example.com/my-input-images.zip"
- },
- "logs": "...",
- "metrics": {
- "predict_time": 502.713876
- },
- "output": {
- "version": "...",
- "weights": "..."
- },
- "started_at": "2023-09-08T16:32:57.112647Z",
- "status": "succeeded",
- "urls": {
- "get": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga",
- "cancel": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel"
- },
- "model": "stability-ai/sdxl",
- "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf"
- }
- ```
-
- `status` will be one of:
-
- - `starting`: the training is starting up. If this status lasts longer than a
- few seconds, then it's typically because a new worker is being started to run
- the training.
- - `processing`: the `train()` method of the model is currently running.
- - `succeeded`: the training completed successfully.
- - `failed`: the training encountered an error during processing.
- - `canceled`: the training was canceled by its creator.
-
- In the case of success, `output` will be an object containing the output of the
- model. Any files will be represented as HTTPS URLs. You'll need to pass the
- `Authorization` header to request them.
-
- In the case of failure, `error` will contain the error encountered during the
- training.
-
- Terminated trainings (with a status of `succeeded`, `failed`, or `canceled`)
- will include a `metrics` object with a `predict_time` property showing the
- amount of CPU or GPU time, in seconds, that the training used while running. It
- won't include time waiting for the training to start.
-
- Args:
- extra_headers: Send extra headers
-
- extra_query: Add additional query parameters to the request
-
- extra_body: Add additional JSON properties to the request
-
- timeout: Override the client-level default timeout for this request, in seconds
- """
- if not training_id:
- raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return self._get(
- f"/trainings/{training_id}",
- options=make_request_options(
- extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
- ),
- cast_to=NoneType,
- )
-
def list(
self,
*,
@@ -141,7 +51,7 @@ def list(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
+ ) -> SyncCursorURLPage[TrainingListResponse]:
"""
Get a paginated list of all trainings created by the user or organization
associated with the provided API token.
@@ -207,13 +117,13 @@ def list(
`version` will be the unique ID of model version used to create the training.
"""
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return self._get(
+ return self._get_api_list(
"/trainings",
+ page=SyncCursorURLPage[TrainingListResponse],
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=NoneType,
+ model=TrainingListResponse,
)
def cancel(
@@ -226,7 +136,7 @@ def cancel(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
+ ) -> TrainingCancelResponse:
"""
Cancel a training
@@ -241,37 +151,15 @@ def cancel(
"""
if not training_id:
raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
return self._post(
f"/trainings/{training_id}/cancel",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=NoneType,
+ cast_to=TrainingCancelResponse,
)
-
-class AsyncTrainingsResource(AsyncAPIResource):
- @cached_property
- def with_raw_response(self) -> AsyncTrainingsResourceWithRawResponse:
- """
- This property can be used as a prefix for any HTTP method call to return
- the raw response object instead of the parsed content.
-
- For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers
- """
- return AsyncTrainingsResourceWithRawResponse(self)
-
- @cached_property
- def with_streaming_response(self) -> AsyncTrainingsResourceWithStreamingResponse:
- """
- An alternative to `.with_raw_response` that doesn't eagerly read the response body.
-
- For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response
- """
- return AsyncTrainingsResourceWithStreamingResponse(self)
-
- async def retrieve(
+ def get(
self,
training_id: str,
*,
@@ -281,7 +169,7 @@ async def retrieve(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
+ ) -> TrainingGetResponse:
"""
Get the current state of a training.
@@ -356,16 +244,36 @@ async def retrieve(
"""
if not training_id:
raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return await self._get(
+ return self._get(
f"/trainings/{training_id}",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=NoneType,
+ cast_to=TrainingGetResponse,
)
- async def list(
+
+class AsyncTrainingsResource(AsyncAPIResource):
+ @cached_property
+ def with_raw_response(self) -> AsyncTrainingsResourceWithRawResponse:
+ """
+ This property can be used as a prefix for any HTTP method call to return
+ the raw response object instead of the parsed content.
+
+ For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers
+ """
+ return AsyncTrainingsResourceWithRawResponse(self)
+
+ @cached_property
+ def with_streaming_response(self) -> AsyncTrainingsResourceWithStreamingResponse:
+ """
+ An alternative to `.with_raw_response` that doesn't eagerly read the response body.
+
+ For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response
+ """
+ return AsyncTrainingsResourceWithStreamingResponse(self)
+
+ def list(
self,
*,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
@@ -374,7 +282,7 @@ async def list(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
+ ) -> AsyncPaginator[TrainingListResponse, AsyncCursorURLPage[TrainingListResponse]]:
"""
Get a paginated list of all trainings created by the user or organization
associated with the provided API token.
@@ -440,13 +348,13 @@ async def list(
`version` will be the unique ID of model version used to create the training.
"""
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
- return await self._get(
+ return self._get_api_list(
"/trainings",
+ page=AsyncCursorURLPage[TrainingListResponse],
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=NoneType,
+ model=TrainingListResponse,
)
async def cancel(
@@ -459,7 +367,7 @@ async def cancel(
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
- ) -> None:
+ ) -> TrainingCancelResponse:
"""
Cancel a training
@@ -474,13 +382,105 @@ async def cancel(
"""
if not training_id:
raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}")
- extra_headers = {"Accept": "*/*", **(extra_headers or {})}
return await self._post(
f"/trainings/{training_id}/cancel",
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
- cast_to=NoneType,
+ cast_to=TrainingCancelResponse,
+ )
+
+ async def get(
+ self,
+ training_id: str,
+ *,
+ # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
+ # The extra values given here take precedence over values defined on the client or passed to this method.
+ extra_headers: Headers | None = None,
+ extra_query: Query | None = None,
+ extra_body: Body | None = None,
+ timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+ ) -> TrainingGetResponse:
+ """
+ Get the current state of a training.
+
+ Example cURL request:
+
+ ```console
+ curl -s \\
+ -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\
+ https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga
+ ```
+
+ The response will be the training object:
+
+ ```json
+ {
+ "completed_at": "2023-09-08T16:41:19.826523Z",
+ "created_at": "2023-09-08T16:32:57.018467Z",
+ "error": null,
+ "id": "zz4ibbonubfz7carwiefibzgga",
+ "input": {
+ "input_images": "https://example.com/my-input-images.zip"
+ },
+ "logs": "...",
+ "metrics": {
+ "predict_time": 502.713876
+ },
+ "output": {
+ "version": "...",
+ "weights": "..."
+ },
+ "started_at": "2023-09-08T16:32:57.112647Z",
+ "status": "succeeded",
+ "urls": {
+ "get": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga",
+ "cancel": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel"
+ },
+ "model": "stability-ai/sdxl",
+ "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf"
+ }
+ ```
+
+ `status` will be one of:
+
+ - `starting`: the training is starting up. If this status lasts longer than a
+ few seconds, then it's typically because a new worker is being started to run
+ the training.
+ - `processing`: the `train()` method of the model is currently running.
+ - `succeeded`: the training completed successfully.
+ - `failed`: the training encountered an error during processing.
+ - `canceled`: the training was canceled by its creator.
+
+ In the case of success, `output` will be an object containing the output of the
+ model. Any files will be represented as HTTPS URLs. You'll need to pass the
+ `Authorization` header to request them.
+
+ In the case of failure, `error` will contain the error encountered during the
+ training.
+
+ Terminated trainings (with a status of `succeeded`, `failed`, or `canceled`)
+ will include a `metrics` object with a `predict_time` property showing the
+ amount of CPU or GPU time, in seconds, that the training used while running. It
+ won't include time waiting for the training to start.
+
+ Args:
+ extra_headers: Send extra headers
+
+ extra_query: Add additional query parameters to the request
+
+ extra_body: Add additional JSON properties to the request
+
+ timeout: Override the client-level default timeout for this request, in seconds
+ """
+ if not training_id:
+ raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}")
+ return await self._get(
+ f"/trainings/{training_id}",
+ options=make_request_options(
+ extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
+ ),
+ cast_to=TrainingGetResponse,
)
@@ -488,57 +488,57 @@ class TrainingsResourceWithRawResponse:
def __init__(self, trainings: TrainingsResource) -> None:
self._trainings = trainings
- self.retrieve = to_raw_response_wrapper(
- trainings.retrieve,
- )
self.list = to_raw_response_wrapper(
trainings.list,
)
self.cancel = to_raw_response_wrapper(
trainings.cancel,
)
+ self.get = to_raw_response_wrapper(
+ trainings.get,
+ )
class AsyncTrainingsResourceWithRawResponse:
def __init__(self, trainings: AsyncTrainingsResource) -> None:
self._trainings = trainings
- self.retrieve = async_to_raw_response_wrapper(
- trainings.retrieve,
- )
self.list = async_to_raw_response_wrapper(
trainings.list,
)
self.cancel = async_to_raw_response_wrapper(
trainings.cancel,
)
+ self.get = async_to_raw_response_wrapper(
+ trainings.get,
+ )
class TrainingsResourceWithStreamingResponse:
def __init__(self, trainings: TrainingsResource) -> None:
self._trainings = trainings
- self.retrieve = to_streamed_response_wrapper(
- trainings.retrieve,
- )
self.list = to_streamed_response_wrapper(
trainings.list,
)
self.cancel = to_streamed_response_wrapper(
trainings.cancel,
)
+ self.get = to_streamed_response_wrapper(
+ trainings.get,
+ )
class AsyncTrainingsResourceWithStreamingResponse:
def __init__(self, trainings: AsyncTrainingsResource) -> None:
self._trainings = trainings
- self.retrieve = async_to_streamed_response_wrapper(
- trainings.retrieve,
- )
self.list = async_to_streamed_response_wrapper(
trainings.list,
)
self.cancel = async_to_streamed_response_wrapper(
trainings.cancel,
)
+ self.get = async_to_streamed_response_wrapper(
+ trainings.get,
+ )
diff --git a/src/replicate/types/__init__.py b/src/replicate/types/__init__.py
index e2b3c58..fa8ee5d 100644
--- a/src/replicate/types/__init__.py
+++ b/src/replicate/types/__init__.py
@@ -7,13 +7,16 @@
from .model_create_params import ModelCreateParams as ModelCreateParams
from .model_list_response import ModelListResponse as ModelListResponse
from .account_list_response import AccountListResponse as AccountListResponse
+from .training_get_response import TrainingGetResponse as TrainingGetResponse
from .hardware_list_response import HardwareListResponse as HardwareListResponse
from .prediction_list_params import PredictionListParams as PredictionListParams
+from .training_list_response import TrainingListResponse as TrainingListResponse
+from .deployment_get_response import DeploymentGetResponse as DeploymentGetResponse
from .deployment_create_params import DeploymentCreateParams as DeploymentCreateParams
from .deployment_list_response import DeploymentListResponse as DeploymentListResponse
from .deployment_update_params import DeploymentUpdateParams as DeploymentUpdateParams
from .prediction_create_params import PredictionCreateParams as PredictionCreateParams
+from .training_cancel_response import TrainingCancelResponse as TrainingCancelResponse
from .deployment_create_response import DeploymentCreateResponse as DeploymentCreateResponse
from .deployment_update_response import DeploymentUpdateResponse as DeploymentUpdateResponse
-from .deployment_retrieve_response import DeploymentRetrieveResponse as DeploymentRetrieveResponse
from .model_create_prediction_params import ModelCreatePredictionParams as ModelCreatePredictionParams
diff --git a/src/replicate/types/deployment_retrieve_response.py b/src/replicate/types/deployment_get_response.py
similarity index 91%
rename from src/replicate/types/deployment_retrieve_response.py
rename to src/replicate/types/deployment_get_response.py
index 2ecc13c..a7281ff 100644
--- a/src/replicate/types/deployment_retrieve_response.py
+++ b/src/replicate/types/deployment_get_response.py
@@ -6,7 +6,7 @@
from .._models import BaseModel
-__all__ = ["DeploymentRetrieveResponse", "CurrentRelease", "CurrentReleaseConfiguration", "CurrentReleaseCreatedBy"]
+__all__ = ["DeploymentGetResponse", "CurrentRelease", "CurrentReleaseConfiguration", "CurrentReleaseCreatedBy"]
class CurrentReleaseConfiguration(BaseModel):
@@ -55,7 +55,7 @@ class CurrentRelease(BaseModel):
"""The ID of the model version used in the release."""
-class DeploymentRetrieveResponse(BaseModel):
+class DeploymentGetResponse(BaseModel):
current_release: Optional[CurrentRelease] = None
name: Optional[str] = None
diff --git a/src/replicate/types/deployment_list_response.py b/src/replicate/types/deployment_list_response.py
index 9f64487..2606d38 100644
--- a/src/replicate/types/deployment_list_response.py
+++ b/src/replicate/types/deployment_list_response.py
@@ -49,11 +49,7 @@ class CurrentRelease(BaseModel):
"""The model identifier string in the format of `{model_owner}/{model_name}`."""
number: Optional[int] = None
- """The release number.
-
- This is an auto-incrementing integer that starts at 1, and is set automatically
- when a deployment is created.
- """
+ """The release number."""
version: Optional[str] = None
"""The ID of the model version used in the release."""
diff --git a/src/replicate/types/models/__init__.py b/src/replicate/types/models/__init__.py
index 98e1a3a..2d89bc6 100644
--- a/src/replicate/types/models/__init__.py
+++ b/src/replicate/types/models/__init__.py
@@ -3,3 +3,4 @@
from __future__ import annotations
from .version_create_training_params import VersionCreateTrainingParams as VersionCreateTrainingParams
+from .version_create_training_response import VersionCreateTrainingResponse as VersionCreateTrainingResponse
diff --git a/src/replicate/types/models/version_create_training_response.py b/src/replicate/types/models/version_create_training_response.py
new file mode 100644
index 0000000..5b005fe
--- /dev/null
+++ b/src/replicate/types/models/version_create_training_response.py
@@ -0,0 +1,74 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Dict, Optional
+from datetime import datetime
+from typing_extensions import Literal
+
+from ..._models import BaseModel
+
+__all__ = ["VersionCreateTrainingResponse", "Metrics", "Output", "URLs"]
+
+
+class Metrics(BaseModel):
+ predict_time: Optional[float] = None
+ """The amount of CPU or GPU time, in seconds, that the training used while running"""
+
+
+class Output(BaseModel):
+ version: Optional[str] = None
+ """The version of the model created by the training"""
+
+ weights: Optional[str] = None
+ """The weights of the trained model"""
+
+
+class URLs(BaseModel):
+ cancel: Optional[str] = None
+ """URL to cancel the training"""
+
+ get: Optional[str] = None
+ """URL to get the training details"""
+
+
+class VersionCreateTrainingResponse(BaseModel):
+ id: Optional[str] = None
+ """The unique ID of the training"""
+
+ completed_at: Optional[datetime] = None
+ """The time when the training completed"""
+
+ created_at: Optional[datetime] = None
+ """The time when the training was created"""
+
+ error: Optional[str] = None
+ """Error message if the training failed"""
+
+ input: Optional[Dict[str, object]] = None
+ """The input parameters used for the training"""
+
+ logs: Optional[str] = None
+ """The logs from the training process"""
+
+ metrics: Optional[Metrics] = None
+ """Metrics about the training process"""
+
+ model: Optional[str] = None
+ """The name of the model in the format owner/name"""
+
+ output: Optional[Output] = None
+ """The output of the training process"""
+
+ source: Optional[Literal["web", "api"]] = None
+ """How the training was created"""
+
+ started_at: Optional[datetime] = None
+ """The time when the training started"""
+
+ status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None
+ """The current status of the training"""
+
+ urls: Optional[URLs] = None
+ """URLs for interacting with the training"""
+
+ version: Optional[str] = None
+ """The ID of the model version used for training"""
diff --git a/src/replicate/types/prediction_output.py b/src/replicate/types/prediction_output.py
index 2c46100..8b320b5 100644
--- a/src/replicate/types/prediction_output.py
+++ b/src/replicate/types/prediction_output.py
@@ -6,5 +6,5 @@
__all__ = ["PredictionOutput"]
PredictionOutput: TypeAlias = Union[
- Optional[Dict[str, object]], Optional[List[object]], Optional[str], Optional[float], Optional[bool]
+ Optional[Dict[str, object]], Optional[List[Dict[str, object]]], Optional[str], Optional[float], Optional[bool]
]
diff --git a/src/replicate/types/training_cancel_response.py b/src/replicate/types/training_cancel_response.py
new file mode 100644
index 0000000..9af512b
--- /dev/null
+++ b/src/replicate/types/training_cancel_response.py
@@ -0,0 +1,74 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Dict, Optional
+from datetime import datetime
+from typing_extensions import Literal
+
+from .._models import BaseModel
+
+__all__ = ["TrainingCancelResponse", "Metrics", "Output", "URLs"]
+
+
+class Metrics(BaseModel):
+ predict_time: Optional[float] = None
+ """The amount of CPU or GPU time, in seconds, that the training used while running"""
+
+
+class Output(BaseModel):
+ version: Optional[str] = None
+ """The version of the model created by the training"""
+
+ weights: Optional[str] = None
+ """The weights of the trained model"""
+
+
+class URLs(BaseModel):
+ cancel: Optional[str] = None
+ """URL to cancel the training"""
+
+ get: Optional[str] = None
+ """URL to get the training details"""
+
+
+class TrainingCancelResponse(BaseModel):
+ id: Optional[str] = None
+ """The unique ID of the training"""
+
+ completed_at: Optional[datetime] = None
+ """The time when the training completed"""
+
+ created_at: Optional[datetime] = None
+ """The time when the training was created"""
+
+ error: Optional[str] = None
+ """Error message if the training failed"""
+
+ input: Optional[Dict[str, object]] = None
+ """The input parameters used for the training"""
+
+ logs: Optional[str] = None
+ """The logs from the training process"""
+
+ metrics: Optional[Metrics] = None
+ """Metrics about the training process"""
+
+ model: Optional[str] = None
+ """The name of the model in the format owner/name"""
+
+ output: Optional[Output] = None
+ """The output of the training process"""
+
+ source: Optional[Literal["web", "api"]] = None
+ """How the training was created"""
+
+ started_at: Optional[datetime] = None
+ """The time when the training started"""
+
+ status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None
+ """The current status of the training"""
+
+ urls: Optional[URLs] = None
+ """URLs for interacting with the training"""
+
+ version: Optional[str] = None
+ """The ID of the model version used for training"""
diff --git a/src/replicate/types/training_get_response.py b/src/replicate/types/training_get_response.py
new file mode 100644
index 0000000..4169da7
--- /dev/null
+++ b/src/replicate/types/training_get_response.py
@@ -0,0 +1,74 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Dict, Optional
+from datetime import datetime
+from typing_extensions import Literal
+
+from .._models import BaseModel
+
+__all__ = ["TrainingGetResponse", "Metrics", "Output", "URLs"]
+
+
+class Metrics(BaseModel):
+ predict_time: Optional[float] = None
+ """The amount of CPU or GPU time, in seconds, that the training used while running"""
+
+
+class Output(BaseModel):
+ version: Optional[str] = None
+ """The version of the model created by the training"""
+
+ weights: Optional[str] = None
+ """The weights of the trained model"""
+
+
+class URLs(BaseModel):
+ cancel: Optional[str] = None
+ """URL to cancel the training"""
+
+ get: Optional[str] = None
+ """URL to get the training details"""
+
+
+class TrainingGetResponse(BaseModel):
+ id: Optional[str] = None
+ """The unique ID of the training"""
+
+ completed_at: Optional[datetime] = None
+ """The time when the training completed"""
+
+ created_at: Optional[datetime] = None
+ """The time when the training was created"""
+
+ error: Optional[str] = None
+ """Error message if the training failed"""
+
+ input: Optional[Dict[str, object]] = None
+ """The input parameters used for the training"""
+
+ logs: Optional[str] = None
+ """The logs from the training process"""
+
+ metrics: Optional[Metrics] = None
+ """Metrics about the training process"""
+
+ model: Optional[str] = None
+ """The name of the model in the format owner/name"""
+
+ output: Optional[Output] = None
+ """The output of the training process"""
+
+ source: Optional[Literal["web", "api"]] = None
+ """How the training was created"""
+
+ started_at: Optional[datetime] = None
+ """The time when the training started"""
+
+ status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None
+ """The current status of the training"""
+
+ urls: Optional[URLs] = None
+ """URLs for interacting with the training"""
+
+ version: Optional[str] = None
+ """The ID of the model version used for training"""
diff --git a/src/replicate/types/training_list_response.py b/src/replicate/types/training_list_response.py
new file mode 100644
index 0000000..02adf3a
--- /dev/null
+++ b/src/replicate/types/training_list_response.py
@@ -0,0 +1,74 @@
+# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
+
+from typing import Dict, Optional
+from datetime import datetime
+from typing_extensions import Literal
+
+from .._models import BaseModel
+
+__all__ = ["TrainingListResponse", "Metrics", "Output", "URLs"]
+
+
+class Metrics(BaseModel):
+ predict_time: Optional[float] = None
+ """The amount of CPU or GPU time, in seconds, that the training used while running"""
+
+
+class Output(BaseModel):
+ version: Optional[str] = None
+ """The version of the model created by the training"""
+
+ weights: Optional[str] = None
+ """The weights of the trained model"""
+
+
+class URLs(BaseModel):
+ cancel: Optional[str] = None
+ """URL to cancel the training"""
+
+ get: Optional[str] = None
+ """URL to get the training details"""
+
+
+class TrainingListResponse(BaseModel):
+ id: Optional[str] = None
+ """The unique ID of the training"""
+
+ completed_at: Optional[datetime] = None
+ """The time when the training completed"""
+
+ created_at: Optional[datetime] = None
+ """The time when the training was created"""
+
+ error: Optional[str] = None
+ """Error message if the training failed"""
+
+ input: Optional[Dict[str, object]] = None
+ """The input parameters used for the training"""
+
+ logs: Optional[str] = None
+ """The logs from the training process"""
+
+ metrics: Optional[Metrics] = None
+ """Metrics about the training process"""
+
+ model: Optional[str] = None
+ """The name of the model in the format owner/name"""
+
+ output: Optional[Output] = None
+ """The output of the training process"""
+
+ source: Optional[Literal["web", "api"]] = None
+ """How the training was created"""
+
+ started_at: Optional[datetime] = None
+ """The time when the training started"""
+
+ status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None
+ """The current status of the training"""
+
+ urls: Optional[URLs] = None
+ """URLs for interacting with the training"""
+
+ version: Optional[str] = None
+ """The ID of the model version used for training"""
diff --git a/tests/api_resources/models/test_versions.py b/tests/api_resources/models/test_versions.py
index 295205e..d1fb7a8 100644
--- a/tests/api_resources/models/test_versions.py
+++ b/tests/api_resources/models/test_versions.py
@@ -8,6 +8,8 @@
import pytest
from replicate import ReplicateClient, AsyncReplicateClient
+from tests.utils import assert_matches_type
+from replicate.types.models import VersionCreateTrainingResponse
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
@@ -15,70 +17,6 @@
class TestVersions:
parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
- @pytest.mark.skip()
- @parametrize
- def test_method_retrieve(self, client: ReplicateClient) -> None:
- version = client.models.versions.retrieve(
- version_id="version_id",
- model_owner="model_owner",
- model_name="model_name",
- )
- assert version is None
-
- @pytest.mark.skip()
- @parametrize
- def test_raw_response_retrieve(self, client: ReplicateClient) -> None:
- response = client.models.versions.with_raw_response.retrieve(
- version_id="version_id",
- model_owner="model_owner",
- model_name="model_name",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- version = response.parse()
- assert version is None
-
- @pytest.mark.skip()
- @parametrize
- def test_streaming_response_retrieve(self, client: ReplicateClient) -> None:
- with client.models.versions.with_streaming_response.retrieve(
- version_id="version_id",
- model_owner="model_owner",
- model_name="model_name",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- version = response.parse()
- assert version is None
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- def test_path_params_retrieve(self, client: ReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
- client.models.versions.with_raw_response.retrieve(
- version_id="version_id",
- model_owner="",
- model_name="model_name",
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
- client.models.versions.with_raw_response.retrieve(
- version_id="version_id",
- model_owner="model_owner",
- model_name="",
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `version_id` but received ''"):
- client.models.versions.with_raw_response.retrieve(
- version_id="",
- model_owner="model_owner",
- model_name="model_name",
- )
-
@pytest.mark.skip()
@parametrize
def test_method_list(self, client: ReplicateClient) -> None:
@@ -205,7 +143,7 @@ def test_method_create_training(self, client: ReplicateClient) -> None:
destination="destination",
input={},
)
- assert version is None
+ assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -219,7 +157,7 @@ def test_method_create_training_with_all_params(self, client: ReplicateClient) -
webhook="webhook",
webhook_events_filter=["start"],
)
- assert version is None
+ assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -235,7 +173,7 @@ def test_raw_response_create_training(self, client: ReplicateClient) -> None:
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
version = response.parse()
- assert version is None
+ assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -251,7 +189,7 @@ def test_streaming_response_create_training(self, client: ReplicateClient) -> No
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
version = response.parse()
- assert version is None
+ assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
assert cast(Any, response.is_closed) is True
@@ -285,14 +223,10 @@ def test_path_params_create_training(self, client: ReplicateClient) -> None:
input={},
)
-
-class TestAsyncVersions:
- parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
-
@pytest.mark.skip()
@parametrize
- async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None:
- version = await async_client.models.versions.retrieve(
+ def test_method_get(self, client: ReplicateClient) -> None:
+ version = client.models.versions.get(
version_id="version_id",
model_owner="model_owner",
model_name="model_name",
@@ -301,8 +235,8 @@ async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None
@pytest.mark.skip()
@parametrize
- async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- response = await async_client.models.versions.with_raw_response.retrieve(
+ def test_raw_response_get(self, client: ReplicateClient) -> None:
+ response = client.models.versions.with_raw_response.get(
version_id="version_id",
model_owner="model_owner",
model_name="model_name",
@@ -310,13 +244,13 @@ async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- version = await response.parse()
+ version = response.parse()
assert version is None
@pytest.mark.skip()
@parametrize
- async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- async with async_client.models.versions.with_streaming_response.retrieve(
+ def test_streaming_response_get(self, client: ReplicateClient) -> None:
+ with client.models.versions.with_streaming_response.get(
version_id="version_id",
model_owner="model_owner",
model_name="model_name",
@@ -324,35 +258,39 @@ async def test_streaming_response_retrieve(self, async_client: AsyncReplicateCli
assert not response.is_closed
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- version = await response.parse()
+ version = response.parse()
assert version is None
assert cast(Any, response.is_closed) is True
@pytest.mark.skip()
@parametrize
- async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None:
+ def test_path_params_get(self, client: ReplicateClient) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
- await async_client.models.versions.with_raw_response.retrieve(
+ client.models.versions.with_raw_response.get(
version_id="version_id",
model_owner="",
model_name="model_name",
)
with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
- await async_client.models.versions.with_raw_response.retrieve(
+ client.models.versions.with_raw_response.get(
version_id="version_id",
model_owner="model_owner",
model_name="",
)
with pytest.raises(ValueError, match=r"Expected a non-empty value for `version_id` but received ''"):
- await async_client.models.versions.with_raw_response.retrieve(
+ client.models.versions.with_raw_response.get(
version_id="",
model_owner="model_owner",
model_name="model_name",
)
+
+class TestAsyncVersions:
+ parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+
@pytest.mark.skip()
@parametrize
async def test_method_list(self, async_client: AsyncReplicateClient) -> None:
@@ -479,7 +417,7 @@ async def test_method_create_training(self, async_client: AsyncReplicateClient)
destination="destination",
input={},
)
- assert version is None
+ assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -493,7 +431,7 @@ async def test_method_create_training_with_all_params(self, async_client: AsyncR
webhook="webhook",
webhook_events_filter=["start"],
)
- assert version is None
+ assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -509,7 +447,7 @@ async def test_raw_response_create_training(self, async_client: AsyncReplicateCl
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
version = await response.parse()
- assert version is None
+ assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -525,7 +463,7 @@ async def test_streaming_response_create_training(self, async_client: AsyncRepli
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
version = await response.parse()
- assert version is None
+ assert_matches_type(VersionCreateTrainingResponse, version, path=["response"])
assert cast(Any, response.is_closed) is True
@@ -558,3 +496,67 @@ async def test_path_params_create_training(self, async_client: AsyncReplicateCli
destination="destination",
input={},
)
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_method_get(self, async_client: AsyncReplicateClient) -> None:
+ version = await async_client.models.versions.get(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="model_name",
+ )
+ assert version is None
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None:
+ response = await async_client.models.versions.with_raw_response.get(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="model_name",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ version = await response.parse()
+ assert version is None
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None:
+ async with async_client.models.versions.with_streaming_response.get(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="model_name",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ version = await response.parse()
+ assert version is None
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
+ await async_client.models.versions.with_raw_response.get(
+ version_id="version_id",
+ model_owner="",
+ model_name="model_name",
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
+ await async_client.models.versions.with_raw_response.get(
+ version_id="version_id",
+ model_owner="model_owner",
+ model_name="",
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `version_id` but received ''"):
+ await async_client.models.versions.with_raw_response.get(
+ version_id="",
+ model_owner="model_owner",
+ model_name="model_name",
+ )
diff --git a/tests/api_resources/test_deployments.py b/tests/api_resources/test_deployments.py
index b834241..3a01dcb 100644
--- a/tests/api_resources/test_deployments.py
+++ b/tests/api_resources/test_deployments.py
@@ -10,10 +10,10 @@
from replicate import ReplicateClient, AsyncReplicateClient
from tests.utils import assert_matches_type
from replicate.types import (
+ DeploymentGetResponse,
DeploymentListResponse,
DeploymentCreateResponse,
DeploymentUpdateResponse,
- DeploymentRetrieveResponse,
)
from replicate.pagination import SyncCursorURLPage, AsyncCursorURLPage
@@ -72,58 +72,6 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
- @parametrize
- def test_method_retrieve(self, client: ReplicateClient) -> None:
- deployment = client.deployments.retrieve(
- deployment_name="deployment_name",
- deployment_owner="deployment_owner",
- )
- assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_raw_response_retrieve(self, client: ReplicateClient) -> None:
- response = client.deployments.with_raw_response.retrieve(
- deployment_name="deployment_name",
- deployment_owner="deployment_owner",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- deployment = response.parse()
- assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_streaming_response_retrieve(self, client: ReplicateClient) -> None:
- with client.deployments.with_streaming_response.retrieve(
- deployment_name="deployment_name",
- deployment_owner="deployment_owner",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- deployment = response.parse()
- assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- def test_path_params_retrieve(self, client: ReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"):
- client.deployments.with_raw_response.retrieve(
- deployment_name="deployment_name",
- deployment_owner="",
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_name` but received ''"):
- client.deployments.with_raw_response.retrieve(
- deployment_name="",
- deployment_owner="deployment_owner",
- )
-
@pytest.mark.skip()
@parametrize
def test_method_update(self, client: ReplicateClient) -> None:
@@ -269,6 +217,58 @@ def test_path_params_delete(self, client: ReplicateClient) -> None:
deployment_owner="deployment_owner",
)
+ @pytest.mark.skip()
+ @parametrize
+ def test_method_get(self, client: ReplicateClient) -> None:
+ deployment = client.deployments.get(
+ deployment_name="deployment_name",
+ deployment_owner="deployment_owner",
+ )
+ assert_matches_type(DeploymentGetResponse, deployment, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_raw_response_get(self, client: ReplicateClient) -> None:
+ response = client.deployments.with_raw_response.get(
+ deployment_name="deployment_name",
+ deployment_owner="deployment_owner",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ deployment = response.parse()
+ assert_matches_type(DeploymentGetResponse, deployment, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_streaming_response_get(self, client: ReplicateClient) -> None:
+ with client.deployments.with_streaming_response.get(
+ deployment_name="deployment_name",
+ deployment_owner="deployment_owner",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ deployment = response.parse()
+ assert_matches_type(DeploymentGetResponse, deployment, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_path_params_get(self, client: ReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"):
+ client.deployments.with_raw_response.get(
+ deployment_name="deployment_name",
+ deployment_owner="",
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_name` but received ''"):
+ client.deployments.with_raw_response.get(
+ deployment_name="",
+ deployment_owner="deployment_owner",
+ )
+
@pytest.mark.skip()
@parametrize
def test_method_list_em_all(self, client: ReplicateClient) -> None:
@@ -350,58 +350,6 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
- @parametrize
- async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None:
- deployment = await async_client.deployments.retrieve(
- deployment_name="deployment_name",
- deployment_owner="deployment_owner",
- )
- assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- response = await async_client.deployments.with_raw_response.retrieve(
- deployment_name="deployment_name",
- deployment_owner="deployment_owner",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- deployment = await response.parse()
- assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- async with async_client.deployments.with_streaming_response.retrieve(
- deployment_name="deployment_name",
- deployment_owner="deployment_owner",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- deployment = await response.parse()
- assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"):
- await async_client.deployments.with_raw_response.retrieve(
- deployment_name="deployment_name",
- deployment_owner="",
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_name` but received ''"):
- await async_client.deployments.with_raw_response.retrieve(
- deployment_name="",
- deployment_owner="deployment_owner",
- )
-
@pytest.mark.skip()
@parametrize
async def test_method_update(self, async_client: AsyncReplicateClient) -> None:
@@ -547,6 +495,58 @@ async def test_path_params_delete(self, async_client: AsyncReplicateClient) -> N
deployment_owner="deployment_owner",
)
+ @pytest.mark.skip()
+ @parametrize
+ async def test_method_get(self, async_client: AsyncReplicateClient) -> None:
+ deployment = await async_client.deployments.get(
+ deployment_name="deployment_name",
+ deployment_owner="deployment_owner",
+ )
+ assert_matches_type(DeploymentGetResponse, deployment, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None:
+ response = await async_client.deployments.with_raw_response.get(
+ deployment_name="deployment_name",
+ deployment_owner="deployment_owner",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ deployment = await response.parse()
+ assert_matches_type(DeploymentGetResponse, deployment, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None:
+ async with async_client.deployments.with_streaming_response.get(
+ deployment_name="deployment_name",
+ deployment_owner="deployment_owner",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ deployment = await response.parse()
+ assert_matches_type(DeploymentGetResponse, deployment, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"):
+ await async_client.deployments.with_raw_response.get(
+ deployment_name="deployment_name",
+ deployment_owner="",
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_name` but received ''"):
+ await async_client.deployments.with_raw_response.get(
+ deployment_name="",
+ deployment_owner="deployment_owner",
+ )
+
@pytest.mark.skip()
@parametrize
async def test_method_list_em_all(self, async_client: AsyncReplicateClient) -> None:
diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py
index f56f00c..a2a9f52 100644
--- a/tests/api_resources/test_models.py
+++ b/tests/api_resources/test_models.py
@@ -77,58 +77,6 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
- @parametrize
- def test_method_retrieve(self, client: ReplicateClient) -> None:
- model = client.models.retrieve(
- model_name="model_name",
- model_owner="model_owner",
- )
- assert model is None
-
- @pytest.mark.skip()
- @parametrize
- def test_raw_response_retrieve(self, client: ReplicateClient) -> None:
- response = client.models.with_raw_response.retrieve(
- model_name="model_name",
- model_owner="model_owner",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- model = response.parse()
- assert model is None
-
- @pytest.mark.skip()
- @parametrize
- def test_streaming_response_retrieve(self, client: ReplicateClient) -> None:
- with client.models.with_streaming_response.retrieve(
- model_name="model_name",
- model_owner="model_owner",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- model = response.parse()
- assert model is None
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- def test_path_params_retrieve(self, client: ReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
- client.models.with_raw_response.retrieve(
- model_name="model_name",
- model_owner="",
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
- client.models.with_raw_response.retrieve(
- model_name="",
- model_owner="model_owner",
- )
-
@pytest.mark.skip()
@parametrize
def test_method_list(self, client: ReplicateClient) -> None:
@@ -280,6 +228,58 @@ def test_path_params_create_prediction(self, client: ReplicateClient) -> None:
input={},
)
+ @pytest.mark.skip()
+ @parametrize
+ def test_method_get(self, client: ReplicateClient) -> None:
+ model = client.models.get(
+ model_name="model_name",
+ model_owner="model_owner",
+ )
+ assert model is None
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_raw_response_get(self, client: ReplicateClient) -> None:
+ response = client.models.with_raw_response.get(
+ model_name="model_name",
+ model_owner="model_owner",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ model = response.parse()
+ assert model is None
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_streaming_response_get(self, client: ReplicateClient) -> None:
+ with client.models.with_streaming_response.get(
+ model_name="model_name",
+ model_owner="model_owner",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ model = response.parse()
+ assert model is None
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_path_params_get(self, client: ReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
+ client.models.with_raw_response.get(
+ model_name="model_name",
+ model_owner="",
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
+ client.models.with_raw_response.get(
+ model_name="",
+ model_owner="model_owner",
+ )
+
class TestAsyncModels:
parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
@@ -343,58 +343,6 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
- @parametrize
- async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None:
- model = await async_client.models.retrieve(
- model_name="model_name",
- model_owner="model_owner",
- )
- assert model is None
-
- @pytest.mark.skip()
- @parametrize
- async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- response = await async_client.models.with_raw_response.retrieve(
- model_name="model_name",
- model_owner="model_owner",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- model = await response.parse()
- assert model is None
-
- @pytest.mark.skip()
- @parametrize
- async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- async with async_client.models.with_streaming_response.retrieve(
- model_name="model_name",
- model_owner="model_owner",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- model = await response.parse()
- assert model is None
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
- await async_client.models.with_raw_response.retrieve(
- model_name="model_name",
- model_owner="",
- )
-
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
- await async_client.models.with_raw_response.retrieve(
- model_name="",
- model_owner="model_owner",
- )
-
@pytest.mark.skip()
@parametrize
async def test_method_list(self, async_client: AsyncReplicateClient) -> None:
@@ -545,3 +493,55 @@ async def test_path_params_create_prediction(self, async_client: AsyncReplicateC
model_owner="model_owner",
input={},
)
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_method_get(self, async_client: AsyncReplicateClient) -> None:
+ model = await async_client.models.get(
+ model_name="model_name",
+ model_owner="model_owner",
+ )
+ assert model is None
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None:
+ response = await async_client.models.with_raw_response.get(
+ model_name="model_name",
+ model_owner="model_owner",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ model = await response.parse()
+ assert model is None
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None:
+ async with async_client.models.with_streaming_response.get(
+ model_name="model_name",
+ model_owner="model_owner",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ model = await response.parse()
+ assert model is None
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"):
+ await async_client.models.with_raw_response.get(
+ model_name="model_name",
+ model_owner="",
+ )
+
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"):
+ await async_client.models.with_raw_response.get(
+ model_name="",
+ model_owner="model_owner",
+ )
diff --git a/tests/api_resources/test_predictions.py b/tests/api_resources/test_predictions.py
index 8fc4142..51bd80d 100644
--- a/tests/api_resources/test_predictions.py
+++ b/tests/api_resources/test_predictions.py
@@ -69,48 +69,6 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None:
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
- @parametrize
- def test_method_retrieve(self, client: ReplicateClient) -> None:
- prediction = client.predictions.retrieve(
- "prediction_id",
- )
- assert_matches_type(Prediction, prediction, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_raw_response_retrieve(self, client: ReplicateClient) -> None:
- response = client.predictions.with_raw_response.retrieve(
- "prediction_id",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- prediction = response.parse()
- assert_matches_type(Prediction, prediction, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- def test_streaming_response_retrieve(self, client: ReplicateClient) -> None:
- with client.predictions.with_streaming_response.retrieve(
- "prediction_id",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- prediction = response.parse()
- assert_matches_type(Prediction, prediction, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- def test_path_params_retrieve(self, client: ReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"):
- client.predictions.with_raw_response.retrieve(
- "",
- )
-
@pytest.mark.skip()
@parametrize
def test_method_list(self, client: ReplicateClient) -> None:
@@ -190,6 +148,48 @@ def test_path_params_cancel(self, client: ReplicateClient) -> None:
"",
)
+ @pytest.mark.skip()
+ @parametrize
+ def test_method_get(self, client: ReplicateClient) -> None:
+ prediction = client.predictions.get(
+ "prediction_id",
+ )
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_raw_response_get(self, client: ReplicateClient) -> None:
+ response = client.predictions.with_raw_response.get(
+ "prediction_id",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ prediction = response.parse()
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_streaming_response_get(self, client: ReplicateClient) -> None:
+ with client.predictions.with_streaming_response.get(
+ "prediction_id",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ prediction = response.parse()
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ def test_path_params_get(self, client: ReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"):
+ client.predictions.with_raw_response.get(
+ "",
+ )
+
class TestAsyncPredictions:
parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
@@ -244,48 +244,6 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien
assert cast(Any, response.is_closed) is True
- @pytest.mark.skip()
- @parametrize
- async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None:
- prediction = await async_client.predictions.retrieve(
- "prediction_id",
- )
- assert_matches_type(Prediction, prediction, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- response = await async_client.predictions.with_raw_response.retrieve(
- "prediction_id",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- prediction = await response.parse()
- assert_matches_type(Prediction, prediction, path=["response"])
-
- @pytest.mark.skip()
- @parametrize
- async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- async with async_client.predictions.with_streaming_response.retrieve(
- "prediction_id",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- prediction = await response.parse()
- assert_matches_type(Prediction, prediction, path=["response"])
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"):
- await async_client.predictions.with_raw_response.retrieve(
- "",
- )
-
@pytest.mark.skip()
@parametrize
async def test_method_list(self, async_client: AsyncReplicateClient) -> None:
@@ -364,3 +322,45 @@ async def test_path_params_cancel(self, async_client: AsyncReplicateClient) -> N
await async_client.predictions.with_raw_response.cancel(
"",
)
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_method_get(self, async_client: AsyncReplicateClient) -> None:
+ prediction = await async_client.predictions.get(
+ "prediction_id",
+ )
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None:
+ response = await async_client.predictions.with_raw_response.get(
+ "prediction_id",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ prediction = await response.parse()
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None:
+ async with async_client.predictions.with_streaming_response.get(
+ "prediction_id",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ prediction = await response.parse()
+ assert_matches_type(Prediction, prediction, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"):
+ await async_client.predictions.with_raw_response.get(
+ "",
+ )
diff --git a/tests/api_resources/test_trainings.py b/tests/api_resources/test_trainings.py
index f59d874..f2dadb1 100644
--- a/tests/api_resources/test_trainings.py
+++ b/tests/api_resources/test_trainings.py
@@ -8,6 +8,9 @@
import pytest
from replicate import ReplicateClient, AsyncReplicateClient
+from tests.utils import assert_matches_type
+from replicate.types import TrainingGetResponse, TrainingListResponse, TrainingCancelResponse
+from replicate.pagination import SyncCursorURLPage, AsyncCursorURLPage
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
@@ -15,53 +18,11 @@
class TestTrainings:
parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"])
- @pytest.mark.skip()
- @parametrize
- def test_method_retrieve(self, client: ReplicateClient) -> None:
- training = client.trainings.retrieve(
- "training_id",
- )
- assert training is None
-
- @pytest.mark.skip()
- @parametrize
- def test_raw_response_retrieve(self, client: ReplicateClient) -> None:
- response = client.trainings.with_raw_response.retrieve(
- "training_id",
- )
-
- assert response.is_closed is True
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- training = response.parse()
- assert training is None
-
- @pytest.mark.skip()
- @parametrize
- def test_streaming_response_retrieve(self, client: ReplicateClient) -> None:
- with client.trainings.with_streaming_response.retrieve(
- "training_id",
- ) as response:
- assert not response.is_closed
- assert response.http_request.headers.get("X-Stainless-Lang") == "python"
-
- training = response.parse()
- assert training is None
-
- assert cast(Any, response.is_closed) is True
-
- @pytest.mark.skip()
- @parametrize
- def test_path_params_retrieve(self, client: ReplicateClient) -> None:
- with pytest.raises(ValueError, match=r"Expected a non-empty value for `training_id` but received ''"):
- client.trainings.with_raw_response.retrieve(
- "",
- )
-
@pytest.mark.skip()
@parametrize
def test_method_list(self, client: ReplicateClient) -> None:
training = client.trainings.list()
- assert training is None
+ assert_matches_type(SyncCursorURLPage[TrainingListResponse], training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -71,7 +32,7 @@ def test_raw_response_list(self, client: ReplicateClient) -> None:
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = response.parse()
- assert training is None
+ assert_matches_type(SyncCursorURLPage[TrainingListResponse], training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -81,7 +42,7 @@ def test_streaming_response_list(self, client: ReplicateClient) -> None:
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = response.parse()
- assert training is None
+ assert_matches_type(SyncCursorURLPage[TrainingListResponse], training, path=["response"])
assert cast(Any, response.is_closed) is True
@@ -91,7 +52,7 @@ def test_method_cancel(self, client: ReplicateClient) -> None:
training = client.trainings.cancel(
"training_id",
)
- assert training is None
+ assert_matches_type(TrainingCancelResponse, training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -103,7 +64,7 @@ def test_raw_response_cancel(self, client: ReplicateClient) -> None:
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = response.parse()
- assert training is None
+ assert_matches_type(TrainingCancelResponse, training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -115,7 +76,7 @@ def test_streaming_response_cancel(self, client: ReplicateClient) -> None:
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = response.parse()
- assert training is None
+ assert_matches_type(TrainingCancelResponse, training, path=["response"])
assert cast(Any, response.is_closed) is True
@@ -127,57 +88,57 @@ def test_path_params_cancel(self, client: ReplicateClient) -> None:
"",
)
-
-class TestAsyncTrainings:
- parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
-
@pytest.mark.skip()
@parametrize
- async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None:
- training = await async_client.trainings.retrieve(
+ def test_method_get(self, client: ReplicateClient) -> None:
+ training = client.trainings.get(
"training_id",
)
- assert training is None
+ assert_matches_type(TrainingGetResponse, training, path=["response"])
@pytest.mark.skip()
@parametrize
- async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- response = await async_client.trainings.with_raw_response.retrieve(
+ def test_raw_response_get(self, client: ReplicateClient) -> None:
+ response = client.trainings.with_raw_response.get(
"training_id",
)
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- training = await response.parse()
- assert training is None
+ training = response.parse()
+ assert_matches_type(TrainingGetResponse, training, path=["response"])
@pytest.mark.skip()
@parametrize
- async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None:
- async with async_client.trainings.with_streaming_response.retrieve(
+ def test_streaming_response_get(self, client: ReplicateClient) -> None:
+ with client.trainings.with_streaming_response.get(
"training_id",
) as response:
assert not response.is_closed
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
- training = await response.parse()
- assert training is None
+ training = response.parse()
+ assert_matches_type(TrainingGetResponse, training, path=["response"])
assert cast(Any, response.is_closed) is True
@pytest.mark.skip()
@parametrize
- async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None:
+ def test_path_params_get(self, client: ReplicateClient) -> None:
with pytest.raises(ValueError, match=r"Expected a non-empty value for `training_id` but received ''"):
- await async_client.trainings.with_raw_response.retrieve(
+ client.trainings.with_raw_response.get(
"",
)
+
+class TestAsyncTrainings:
+ parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"])
+
@pytest.mark.skip()
@parametrize
async def test_method_list(self, async_client: AsyncReplicateClient) -> None:
training = await async_client.trainings.list()
- assert training is None
+ assert_matches_type(AsyncCursorURLPage[TrainingListResponse], training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -187,7 +148,7 @@ async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> No
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = await response.parse()
- assert training is None
+ assert_matches_type(AsyncCursorURLPage[TrainingListResponse], training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -197,7 +158,7 @@ async def test_streaming_response_list(self, async_client: AsyncReplicateClient)
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = await response.parse()
- assert training is None
+ assert_matches_type(AsyncCursorURLPage[TrainingListResponse], training, path=["response"])
assert cast(Any, response.is_closed) is True
@@ -207,7 +168,7 @@ async def test_method_cancel(self, async_client: AsyncReplicateClient) -> None:
training = await async_client.trainings.cancel(
"training_id",
)
- assert training is None
+ assert_matches_type(TrainingCancelResponse, training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -219,7 +180,7 @@ async def test_raw_response_cancel(self, async_client: AsyncReplicateClient) ->
assert response.is_closed is True
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = await response.parse()
- assert training is None
+ assert_matches_type(TrainingCancelResponse, training, path=["response"])
@pytest.mark.skip()
@parametrize
@@ -231,7 +192,7 @@ async def test_streaming_response_cancel(self, async_client: AsyncReplicateClien
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
training = await response.parse()
- assert training is None
+ assert_matches_type(TrainingCancelResponse, training, path=["response"])
assert cast(Any, response.is_closed) is True
@@ -242,3 +203,45 @@ async def test_path_params_cancel(self, async_client: AsyncReplicateClient) -> N
await async_client.trainings.with_raw_response.cancel(
"",
)
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_method_get(self, async_client: AsyncReplicateClient) -> None:
+ training = await async_client.trainings.get(
+ "training_id",
+ )
+ assert_matches_type(TrainingGetResponse, training, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None:
+ response = await async_client.trainings.with_raw_response.get(
+ "training_id",
+ )
+
+ assert response.is_closed is True
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+ training = await response.parse()
+ assert_matches_type(TrainingGetResponse, training, path=["response"])
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None:
+ async with async_client.trainings.with_streaming_response.get(
+ "training_id",
+ ) as response:
+ assert not response.is_closed
+ assert response.http_request.headers.get("X-Stainless-Lang") == "python"
+
+ training = await response.parse()
+ assert_matches_type(TrainingGetResponse, training, path=["response"])
+
+ assert cast(Any, response.is_closed) is True
+
+ @pytest.mark.skip()
+ @parametrize
+ async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None:
+ with pytest.raises(ValueError, match=r"Expected a non-empty value for `training_id` but received ''"):
+ await async_client.trainings.with_raw_response.get(
+ "",
+ )
diff --git a/tests/conftest.py b/tests/conftest.py
index 97007f1..7d1538b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -10,7 +10,7 @@
from replicate import ReplicateClient, AsyncReplicateClient
if TYPE_CHECKING:
- from _pytest.fixtures import FixtureRequest
+ from _pytest.fixtures import FixtureRequest # pyright: ignore[reportPrivateImportUsage]
pytest.register_assert_rewrite("tests.utils")
diff --git a/tests/test_client.py b/tests/test_client.py
index bd5b999..279675f 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -346,7 +346,7 @@ def test_validate_headers(self) -> None:
assert request.headers.get("Authorization") == f"Bearer {bearer_token}"
with pytest.raises(ReplicateClientError):
- with update_env(**{"REPLICATE_CLIENT_BEARER_TOKEN": Omit()}):
+ with update_env(**{"REPLICATE_API_TOKEN": Omit()}):
client2 = ReplicateClient(base_url=base_url, bearer_token=None, _strict_response_validation=True)
_ = client2
@@ -1126,7 +1126,7 @@ def test_validate_headers(self) -> None:
assert request.headers.get("Authorization") == f"Bearer {bearer_token}"
with pytest.raises(ReplicateClientError):
- with update_env(**{"REPLICATE_CLIENT_BEARER_TOKEN": Omit()}):
+ with update_env(**{"REPLICATE_API_TOKEN": Omit()}):
client2 = AsyncReplicateClient(base_url=base_url, bearer_token=None, _strict_response_validation=True)
_ = client2
diff --git a/tests/test_models.py b/tests/test_models.py
index ee374fa..cf8173c 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -492,12 +492,15 @@ class Model(BaseModel):
resource_id: Optional[str] = None
m = Model.construct()
+ assert m.resource_id is None
assert "resource_id" not in m.model_fields_set
m = Model.construct(resource_id=None)
+ assert m.resource_id is None
assert "resource_id" in m.model_fields_set
m = Model.construct(resource_id="foo")
+ assert m.resource_id == "foo"
assert "resource_id" in m.model_fields_set
@@ -832,7 +835,7 @@ class B(BaseModel):
@pytest.mark.skipif(not PYDANTIC_V2, reason="TypeAliasType is not supported in Pydantic v1")
def test_type_alias_type() -> None:
- Alias = TypeAliasType("Alias", str)
+ Alias = TypeAliasType("Alias", str) # pyright: ignore
class Model(BaseModel):
alias: Alias