diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 81f6dc2..04b083c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,6 +10,7 @@ on: jobs: lint: + timeout-minutes: 10 name: lint runs-on: ubuntu-latest steps: @@ -30,6 +31,7 @@ jobs: run: ./scripts/lint test: + timeout-minutes: 10 name: test runs-on: ubuntu-latest steps: diff --git a/.release-please-manifest.json b/.release-please-manifest.json index f14b480..aaf968a 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.1.0-alpha.2" + ".": "0.1.0-alpha.3" } \ No newline at end of file diff --git a/.stats.yml b/.stats.yml index 91aadf3..190fe36 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 27 -openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-37bb31ed76da599d3bded543a3765f745c8575d105c13554df7f8361c3641482.yml -openapi_spec_hash: 15bdec12ca84042768bfb28cc48dfce3 -config_hash: 810de4c2eee1a7649263cff01f00da7c +openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-2788217b7ad7d61d1a77800bc5ff12a6810f1692d4d770b72fa8f898c6a055ab.yml +openapi_spec_hash: 4423bf747e228484547b441468a9f156 +config_hash: d1d273c0d97d034d24c7eac8ef51d2ac diff --git a/CHANGELOG.md b/CHANGELOG.md index d51ac7f..6bde67f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,34 @@ # Changelog +## 0.1.0-alpha.3 (2025-04-23) + +Full Changelog: [v0.1.0-alpha.2...v0.1.0-alpha.3](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0-alpha.2...v0.1.0-alpha.3) + +### ⚠ BREAKING CHANGES + +* **api:** use correct env var for bearer token + +### Features + +* **api:** api update ([7ebd598](https://github.com/replicate/replicate-python-stainless/commit/7ebd59873181c74dbaa035ac599abcbbefb3ee62)) + + +### Bug Fixes + +* **api:** use correct env var for bearer token ([00eab77](https://github.com/replicate/replicate-python-stainless/commit/00eab7702f8f2699ce9b3070f23202278ac21855)) +* **pydantic v1:** more robust ModelField.annotation check ([c907599](https://github.com/replicate/replicate-python-stainless/commit/c907599a6736e781f3f80062eb4d03ed92f03403)) + + +### Chores + +* **ci:** add timeout thresholds for CI jobs ([1bad4d3](https://github.com/replicate/replicate-python-stainless/commit/1bad4d3d3676a323032f37f0195ff640fcce3458)) +* **internal:** base client updates ([c1d6ed5](https://github.com/replicate/replicate-python-stainless/commit/c1d6ed59ed0f06012922ec6d0bae376852523d81)) +* **internal:** bump pyright version ([f1e4d14](https://github.com/replicate/replicate-python-stainless/commit/f1e4d140104ff317b94cb2dd88ec850a9b8bce54)) +* **internal:** fix list file params ([2918eba](https://github.com/replicate/replicate-python-stainless/commit/2918ebad39df868485fed02a2d0020bef72d24b9)) +* **internal:** import reformatting ([4cdf515](https://github.com/replicate/replicate-python-stainless/commit/4cdf515372a9e936c3a18afd24a444a778b1f7f5)) +* **internal:** refactor retries to not use recursion ([75005e1](https://github.com/replicate/replicate-python-stainless/commit/75005e11045385d0596911bbbbb062207450bd14)) +* **internal:** update models test ([fc34c6d](https://github.com/replicate/replicate-python-stainless/commit/fc34c6d4fc36a41441ab8417f85343e640b53b76)) + ## 0.1.0-alpha.2 (2025-04-16) Full Changelog: [v0.1.0-alpha.1...v0.1.0-alpha.2](https://github.com/replicate/replicate-python-stainless/compare/v0.1.0-alpha.1...v0.1.0-alpha.2) diff --git a/README.md b/README.md index 0015b4c..23b57ae 100644 --- a/README.md +++ b/README.md @@ -28,9 +28,7 @@ import os from replicate import ReplicateClient client = ReplicateClient( - bearer_token=os.environ.get( - "REPLICATE_CLIENT_BEARER_TOKEN" - ), # This is the default and can be omitted + bearer_token=os.environ.get("REPLICATE_API_TOKEN"), # This is the default and can be omitted ) accounts = client.accounts.list() @@ -39,7 +37,7 @@ print(accounts.type) While you can provide a `bearer_token` keyword argument, we recommend using [python-dotenv](https://pypi.org/project/python-dotenv/) -to add `REPLICATE_CLIENT_BEARER_TOKEN="My Bearer Token"` to your `.env` file +to add `REPLICATE_API_TOKEN="My Bearer Token"` to your `.env` file so that your Bearer Token is not stored in source control. ## Async usage @@ -52,9 +50,7 @@ import asyncio from replicate import AsyncReplicateClient client = AsyncReplicateClient( - bearer_token=os.environ.get( - "REPLICATE_CLIENT_BEARER_TOKEN" - ), # This is the default and can be omitted + bearer_token=os.environ.get("REPLICATE_API_TOKEN"), # This is the default and can be omitted ) diff --git a/api.md b/api.md index 60e91e6..e972297 100644 --- a/api.md +++ b/api.md @@ -11,19 +11,19 @@ Types: ```python from replicate.types import ( DeploymentCreateResponse, - DeploymentRetrieveResponse, DeploymentUpdateResponse, DeploymentListResponse, + DeploymentGetResponse, ) ``` Methods: - client.deployments.create(\*\*params) -> DeploymentCreateResponse -- client.deployments.retrieve(deployment_name, \*, deployment_owner) -> DeploymentRetrieveResponse - client.deployments.update(deployment_name, \*, deployment_owner, \*\*params) -> DeploymentUpdateResponse - client.deployments.list() -> SyncCursorURLPage[DeploymentListResponse] - client.deployments.delete(deployment_name, \*, deployment_owner) -> None +- client.deployments.get(deployment_name, \*, deployment_owner) -> DeploymentGetResponse - client.deployments.list_em_all() -> None ## Predictions @@ -68,19 +68,25 @@ from replicate.types import ModelListResponse Methods: - client.models.create(\*\*params) -> None -- client.models.retrieve(model_name, \*, model_owner) -> None - client.models.list() -> SyncCursorURLPage[ModelListResponse] - client.models.delete(model_name, \*, model_owner) -> None - client.models.create_prediction(model_name, \*, model_owner, \*\*params) -> Prediction +- client.models.get(model_name, \*, model_owner) -> None ## Versions +Types: + +```python +from replicate.types.models import VersionCreateTrainingResponse +``` + Methods: -- client.models.versions.retrieve(version_id, \*, model_owner, model_name) -> None - client.models.versions.list(model_name, \*, model_owner) -> None - client.models.versions.delete(version_id, \*, model_owner, model_name) -> None -- client.models.versions.create_training(version_id, \*, model_owner, model_name, \*\*params) -> None +- client.models.versions.create_training(version_id, \*, model_owner, model_name, \*\*params) -> VersionCreateTrainingResponse +- client.models.versions.get(version_id, \*, model_owner, model_name) -> None # Predictions @@ -93,17 +99,23 @@ from replicate.types import Prediction, PredictionOutput, PredictionRequest Methods: - client.predictions.create(\*\*params) -> Prediction -- client.predictions.retrieve(prediction_id) -> Prediction - client.predictions.list(\*\*params) -> SyncCursorURLPageWithCreatedFilters[Prediction] - client.predictions.cancel(prediction_id) -> None +- client.predictions.get(prediction_id) -> Prediction # Trainings +Types: + +```python +from replicate.types import TrainingListResponse, TrainingCancelResponse, TrainingGetResponse +``` + Methods: -- client.trainings.retrieve(training_id) -> None -- client.trainings.list() -> None -- client.trainings.cancel(training_id) -> None +- client.trainings.list() -> SyncCursorURLPage[TrainingListResponse] +- client.trainings.cancel(training_id) -> TrainingCancelResponse +- client.trainings.get(training_id) -> TrainingGetResponse # Webhooks diff --git a/pyproject.toml b/pyproject.toml index d9681a8..4376f13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "replicate-stainless" -version = "0.1.0-alpha.2" +version = "0.1.0-alpha.3" description = "The official Python library for the replicate-client API" dynamic = ["readme"] license = "Apache-2.0" @@ -42,7 +42,7 @@ Repository = "https://github.com/replicate/replicate-python-stainless" managed = true # version pins are in requirements-dev.lock dev-dependencies = [ - "pyright>=1.1.359", + "pyright==1.1.399", "mypy", "respx", "pytest", diff --git a/requirements-dev.lock b/requirements-dev.lock index e9ea098..86eea12 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -69,7 +69,7 @@ pydantic-core==2.27.1 # via pydantic pygments==2.18.0 # via rich -pyright==1.1.392.post0 +pyright==1.1.399 pytest==8.3.3 # via pytest-asyncio pytest-asyncio==0.24.0 diff --git a/src/replicate/_base_client.py b/src/replicate/_base_client.py index d55fecf..84db2c9 100644 --- a/src/replicate/_base_client.py +++ b/src/replicate/_base_client.py @@ -98,7 +98,11 @@ _AsyncStreamT = TypeVar("_AsyncStreamT", bound=AsyncStream[Any]) if TYPE_CHECKING: - from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT + from httpx._config import ( + DEFAULT_TIMEOUT_CONFIG, # pyright: ignore[reportPrivateImportUsage] + ) + + HTTPX_DEFAULT_TIMEOUT = DEFAULT_TIMEOUT_CONFIG else: try: from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT @@ -115,6 +119,7 @@ class PageInfo: url: URL | NotGiven params: Query | NotGiven + json: Body | NotGiven @overload def __init__( @@ -130,19 +135,30 @@ def __init__( params: Query, ) -> None: ... + @overload + def __init__( + self, + *, + json: Body, + ) -> None: ... + def __init__( self, *, url: URL | NotGiven = NOT_GIVEN, + json: Body | NotGiven = NOT_GIVEN, params: Query | NotGiven = NOT_GIVEN, ) -> None: self.url = url + self.json = json self.params = params @override def __repr__(self) -> str: if self.url: return f"{self.__class__.__name__}(url={self.url})" + if self.json: + return f"{self.__class__.__name__}(json={self.json})" return f"{self.__class__.__name__}(params={self.params})" @@ -191,6 +207,19 @@ def _info_to_options(self, info: PageInfo) -> FinalRequestOptions: options.url = str(url) return options + if not isinstance(info.json, NotGiven): + if not is_mapping(info.json): + raise TypeError("Pagination is only supported with mappings") + + if not options.json_data: + options.json_data = {**info.json} + else: + if not is_mapping(options.json_data): + raise TypeError("Pagination is only supported with mappings") + + options.json_data = {**options.json_data, **info.json} + return options + raise ValueError("Unexpected PageInfo state") @@ -408,8 +437,7 @@ def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0 headers = httpx.Headers(headers_dict) idempotency_header = self._idempotency_header - if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers: - options.idempotency_key = options.idempotency_key or self._idempotency_key() + if idempotency_header and options.idempotency_key and idempotency_header not in headers: headers[idempotency_header] = options.idempotency_key # Don't set these headers if they were already set or removed by the caller. We check @@ -874,7 +902,6 @@ def request( self, cast_to: Type[ResponseT], options: FinalRequestOptions, - remaining_retries: Optional[int] = None, *, stream: Literal[True], stream_cls: Type[_StreamT], @@ -885,7 +912,6 @@ def request( self, cast_to: Type[ResponseT], options: FinalRequestOptions, - remaining_retries: Optional[int] = None, *, stream: Literal[False] = False, ) -> ResponseT: ... @@ -895,7 +921,6 @@ def request( self, cast_to: Type[ResponseT], options: FinalRequestOptions, - remaining_retries: Optional[int] = None, *, stream: bool = False, stream_cls: Type[_StreamT] | None = None, @@ -905,125 +930,109 @@ def request( self, cast_to: Type[ResponseT], options: FinalRequestOptions, - remaining_retries: Optional[int] = None, *, stream: bool = False, stream_cls: type[_StreamT] | None = None, ) -> ResponseT | _StreamT: - if remaining_retries is not None: - retries_taken = options.get_max_retries(self.max_retries) - remaining_retries - else: - retries_taken = 0 - - return self._request( - cast_to=cast_to, - options=options, - stream=stream, - stream_cls=stream_cls, - retries_taken=retries_taken, - ) + cast_to = self._maybe_override_cast_to(cast_to, options) - def _request( - self, - *, - cast_to: Type[ResponseT], - options: FinalRequestOptions, - retries_taken: int, - stream: bool, - stream_cls: type[_StreamT] | None, - ) -> ResponseT | _StreamT: # create a copy of the options we were given so that if the # options are mutated later & we then retry, the retries are # given the original options input_options = model_copy(options) - - cast_to = self._maybe_override_cast_to(cast_to, options) - options = self._prepare_options(options) - - remaining_retries = options.get_max_retries(self.max_retries) - retries_taken - request = self._build_request(options, retries_taken=retries_taken) - self._prepare_request(request) - - if options.idempotency_key: + if input_options.idempotency_key is None and input_options.method.lower() != "get": # ensure the idempotency key is reused between requests - input_options.idempotency_key = options.idempotency_key + input_options.idempotency_key = self._idempotency_key() - kwargs: HttpxSendArgs = {} - if self.custom_auth is not None: - kwargs["auth"] = self.custom_auth + response: httpx.Response | None = None + max_retries = input_options.get_max_retries(self.max_retries) - log.debug("Sending HTTP Request: %s %s", request.method, request.url) + retries_taken = 0 + for retries_taken in range(max_retries + 1): + options = model_copy(input_options) + options = self._prepare_options(options) - try: - response = self._client.send( - request, - stream=stream or self._should_stream_response_body(request=request), - **kwargs, - ) - except httpx.TimeoutException as err: - log.debug("Encountered httpx.TimeoutException", exc_info=True) + remaining_retries = max_retries - retries_taken + request = self._build_request(options, retries_taken=retries_taken) + self._prepare_request(request) - if remaining_retries > 0: - return self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - stream=stream, - stream_cls=stream_cls, - response_headers=None, - ) + kwargs: HttpxSendArgs = {} + if self.custom_auth is not None: + kwargs["auth"] = self.custom_auth - log.debug("Raising timeout error") - raise APITimeoutError(request=request) from err - except Exception as err: - log.debug("Encountered Exception", exc_info=True) + log.debug("Sending HTTP Request: %s %s", request.method, request.url) - if remaining_retries > 0: - return self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - stream=stream, - stream_cls=stream_cls, - response_headers=None, + response = None + try: + response = self._client.send( + request, + stream=stream or self._should_stream_response_body(request=request), + **kwargs, ) + except httpx.TimeoutException as err: + log.debug("Encountered httpx.TimeoutException", exc_info=True) + + if remaining_retries > 0: + self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=None, + ) + continue + + log.debug("Raising timeout error") + raise APITimeoutError(request=request) from err + except Exception as err: + log.debug("Encountered Exception", exc_info=True) + + if remaining_retries > 0: + self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=None, + ) + continue + + log.debug("Raising connection error") + raise APIConnectionError(request=request) from err + + log.debug( + 'HTTP Response: %s %s "%i %s" %s', + request.method, + request.url, + response.status_code, + response.reason_phrase, + response.headers, + ) - log.debug("Raising connection error") - raise APIConnectionError(request=request) from err - - log.debug( - 'HTTP Response: %s %s "%i %s" %s', - request.method, - request.url, - response.status_code, - response.reason_phrase, - response.headers, - ) + try: + response.raise_for_status() + except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code + log.debug("Encountered httpx.HTTPStatusError", exc_info=True) + + if remaining_retries > 0 and self._should_retry(err.response): + err.response.close() + self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=response, + ) + continue - try: - response.raise_for_status() - except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code - log.debug("Encountered httpx.HTTPStatusError", exc_info=True) - - if remaining_retries > 0 and self._should_retry(err.response): - err.response.close() - return self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - response_headers=err.response.headers, - stream=stream, - stream_cls=stream_cls, - ) + # If the response is streamed then we need to explicitly read the response + # to completion before attempting to access the response text. + if not err.response.is_closed: + err.response.read() - # If the response is streamed then we need to explicitly read the response - # to completion before attempting to access the response text. - if not err.response.is_closed: - err.response.read() + log.debug("Re-raising status error") + raise self._make_status_error_from_response(err.response) from None - log.debug("Re-raising status error") - raise self._make_status_error_from_response(err.response) from None + break + assert response is not None, "could not resolve response (should never happen)" return self._process_response( cast_to=cast_to, options=options, @@ -1033,37 +1042,20 @@ def _request( retries_taken=retries_taken, ) - def _retry_request( - self, - options: FinalRequestOptions, - cast_to: Type[ResponseT], - *, - retries_taken: int, - response_headers: httpx.Headers | None, - stream: bool, - stream_cls: type[_StreamT] | None, - ) -> ResponseT | _StreamT: - remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + def _sleep_for_retry( + self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None + ) -> None: + remaining_retries = max_retries - retries_taken if remaining_retries == 1: log.debug("1 retry left") else: log.debug("%i retries left", remaining_retries) - timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers) + timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None) log.info("Retrying request to %s in %f seconds", options.url, timeout) - # In a synchronous context we are blocking the entire thread. Up to the library user to run the client in a - # different thread if necessary. time.sleep(timeout) - return self._request( - options=options, - cast_to=cast_to, - retries_taken=retries_taken + 1, - stream=stream, - stream_cls=stream_cls, - ) - def _process_response( self, *, @@ -1407,7 +1399,6 @@ async def request( options: FinalRequestOptions, *, stream: Literal[False] = False, - remaining_retries: Optional[int] = None, ) -> ResponseT: ... @overload @@ -1418,7 +1409,6 @@ async def request( *, stream: Literal[True], stream_cls: type[_AsyncStreamT], - remaining_retries: Optional[int] = None, ) -> _AsyncStreamT: ... @overload @@ -1429,7 +1419,6 @@ async def request( *, stream: bool, stream_cls: type[_AsyncStreamT] | None = None, - remaining_retries: Optional[int] = None, ) -> ResponseT | _AsyncStreamT: ... async def request( @@ -1439,120 +1428,111 @@ async def request( *, stream: bool = False, stream_cls: type[_AsyncStreamT] | None = None, - remaining_retries: Optional[int] = None, - ) -> ResponseT | _AsyncStreamT: - if remaining_retries is not None: - retries_taken = options.get_max_retries(self.max_retries) - remaining_retries - else: - retries_taken = 0 - - return await self._request( - cast_to=cast_to, - options=options, - stream=stream, - stream_cls=stream_cls, - retries_taken=retries_taken, - ) - - async def _request( - self, - cast_to: Type[ResponseT], - options: FinalRequestOptions, - *, - stream: bool, - stream_cls: type[_AsyncStreamT] | None, - retries_taken: int, ) -> ResponseT | _AsyncStreamT: if self._platform is None: # `get_platform` can make blocking IO calls so we # execute it earlier while we are in an async context self._platform = await asyncify(get_platform)() + cast_to = self._maybe_override_cast_to(cast_to, options) + # create a copy of the options we were given so that if the # options are mutated later & we then retry, the retries are # given the original options input_options = model_copy(options) - - cast_to = self._maybe_override_cast_to(cast_to, options) - options = await self._prepare_options(options) - - remaining_retries = options.get_max_retries(self.max_retries) - retries_taken - request = self._build_request(options, retries_taken=retries_taken) - await self._prepare_request(request) - - if options.idempotency_key: + if input_options.idempotency_key is None and input_options.method.lower() != "get": # ensure the idempotency key is reused between requests - input_options.idempotency_key = options.idempotency_key + input_options.idempotency_key = self._idempotency_key() - kwargs: HttpxSendArgs = {} - if self.custom_auth is not None: - kwargs["auth"] = self.custom_auth + response: httpx.Response | None = None + max_retries = input_options.get_max_retries(self.max_retries) - try: - response = await self._client.send( - request, - stream=stream or self._should_stream_response_body(request=request), - **kwargs, - ) - except httpx.TimeoutException as err: - log.debug("Encountered httpx.TimeoutException", exc_info=True) + retries_taken = 0 + for retries_taken in range(max_retries + 1): + options = model_copy(input_options) + options = await self._prepare_options(options) - if remaining_retries > 0: - return await self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - stream=stream, - stream_cls=stream_cls, - response_headers=None, - ) + remaining_retries = max_retries - retries_taken + request = self._build_request(options, retries_taken=retries_taken) + await self._prepare_request(request) - log.debug("Raising timeout error") - raise APITimeoutError(request=request) from err - except Exception as err: - log.debug("Encountered Exception", exc_info=True) + kwargs: HttpxSendArgs = {} + if self.custom_auth is not None: + kwargs["auth"] = self.custom_auth - if remaining_retries > 0: - return await self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - stream=stream, - stream_cls=stream_cls, - response_headers=None, - ) + log.debug("Sending HTTP Request: %s %s", request.method, request.url) - log.debug("Raising connection error") - raise APIConnectionError(request=request) from err + response = None + try: + response = await self._client.send( + request, + stream=stream or self._should_stream_response_body(request=request), + **kwargs, + ) + except httpx.TimeoutException as err: + log.debug("Encountered httpx.TimeoutException", exc_info=True) + + if remaining_retries > 0: + await self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=None, + ) + continue + + log.debug("Raising timeout error") + raise APITimeoutError(request=request) from err + except Exception as err: + log.debug("Encountered Exception", exc_info=True) + + if remaining_retries > 0: + await self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=None, + ) + continue + + log.debug("Raising connection error") + raise APIConnectionError(request=request) from err + + log.debug( + 'HTTP Response: %s %s "%i %s" %s', + request.method, + request.url, + response.status_code, + response.reason_phrase, + response.headers, + ) - log.debug( - 'HTTP Request: %s %s "%i %s"', request.method, request.url, response.status_code, response.reason_phrase - ) + try: + response.raise_for_status() + except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code + log.debug("Encountered httpx.HTTPStatusError", exc_info=True) + + if remaining_retries > 0 and self._should_retry(err.response): + await err.response.aclose() + await self._sleep_for_retry( + retries_taken=retries_taken, + max_retries=max_retries, + options=input_options, + response=response, + ) + continue - try: - response.raise_for_status() - except httpx.HTTPStatusError as err: # thrown on 4xx and 5xx status code - log.debug("Encountered httpx.HTTPStatusError", exc_info=True) - - if remaining_retries > 0 and self._should_retry(err.response): - await err.response.aclose() - return await self._retry_request( - input_options, - cast_to, - retries_taken=retries_taken, - response_headers=err.response.headers, - stream=stream, - stream_cls=stream_cls, - ) + # If the response is streamed then we need to explicitly read the response + # to completion before attempting to access the response text. + if not err.response.is_closed: + await err.response.aread() - # If the response is streamed then we need to explicitly read the response - # to completion before attempting to access the response text. - if not err.response.is_closed: - await err.response.aread() + log.debug("Re-raising status error") + raise self._make_status_error_from_response(err.response) from None - log.debug("Re-raising status error") - raise self._make_status_error_from_response(err.response) from None + break + assert response is not None, "could not resolve response (should never happen)" return await self._process_response( cast_to=cast_to, options=options, @@ -1562,35 +1542,20 @@ async def _request( retries_taken=retries_taken, ) - async def _retry_request( - self, - options: FinalRequestOptions, - cast_to: Type[ResponseT], - *, - retries_taken: int, - response_headers: httpx.Headers | None, - stream: bool, - stream_cls: type[_AsyncStreamT] | None, - ) -> ResponseT | _AsyncStreamT: - remaining_retries = options.get_max_retries(self.max_retries) - retries_taken + async def _sleep_for_retry( + self, *, retries_taken: int, max_retries: int, options: FinalRequestOptions, response: httpx.Response | None + ) -> None: + remaining_retries = max_retries - retries_taken if remaining_retries == 1: log.debug("1 retry left") else: log.debug("%i retries left", remaining_retries) - timeout = self._calculate_retry_timeout(remaining_retries, options, response_headers) + timeout = self._calculate_retry_timeout(remaining_retries, options, response.headers if response else None) log.info("Retrying request to %s in %f seconds", options.url, timeout) await anyio.sleep(timeout) - return await self._request( - options=options, - cast_to=cast_to, - retries_taken=retries_taken + 1, - stream=stream, - stream_cls=stream_cls, - ) - async def _process_response( self, *, diff --git a/src/replicate/_client.py b/src/replicate/_client.py index 06e86ff..6afa0ec 100644 --- a/src/replicate/_client.py +++ b/src/replicate/_client.py @@ -19,10 +19,7 @@ ProxiesTypes, RequestOptions, ) -from ._utils import ( - is_given, - get_async_library, -) +from ._utils import is_given, get_async_library from ._version import __version__ from .resources import accounts, hardware, trainings, collections, predictions from ._streaming import Stream as Stream, AsyncStream as AsyncStream @@ -88,13 +85,13 @@ def __init__( ) -> None: """Construct a new synchronous ReplicateClient client instance. - This automatically infers the `bearer_token` argument from the `REPLICATE_CLIENT_BEARER_TOKEN` environment variable if it is not provided. + This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided. """ if bearer_token is None: - bearer_token = os.environ.get("REPLICATE_CLIENT_BEARER_TOKEN") + bearer_token = os.environ.get("REPLICATE_API_TOKEN") if bearer_token is None: raise ReplicateClientError( - "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_CLIENT_BEARER_TOKEN environment variable" + "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable" ) self.bearer_token = bearer_token @@ -270,13 +267,13 @@ def __init__( ) -> None: """Construct a new async AsyncReplicateClient client instance. - This automatically infers the `bearer_token` argument from the `REPLICATE_CLIENT_BEARER_TOKEN` environment variable if it is not provided. + This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided. """ if bearer_token is None: - bearer_token = os.environ.get("REPLICATE_CLIENT_BEARER_TOKEN") + bearer_token = os.environ.get("REPLICATE_API_TOKEN") if bearer_token is None: raise ReplicateClientError( - "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_CLIENT_BEARER_TOKEN environment variable" + "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable" ) self.bearer_token = bearer_token diff --git a/src/replicate/_models.py b/src/replicate/_models.py index 3493571..798956f 100644 --- a/src/replicate/_models.py +++ b/src/replicate/_models.py @@ -19,7 +19,6 @@ ) import pydantic -import pydantic.generics from pydantic.fields import FieldInfo from ._types import ( @@ -627,8 +626,8 @@ def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, # Note: if one variant defines an alias then they all should discriminator_alias = field_info.alias - if field_info.annotation and is_literal_type(field_info.annotation): - for entry in get_args(field_info.annotation): + if (annotation := getattr(field_info, "annotation", None)) and is_literal_type(annotation): + for entry in get_args(annotation): if isinstance(entry, str): mapping[entry] = variant diff --git a/src/replicate/_utils/_typing.py b/src/replicate/_utils/_typing.py index 1958820..1bac954 100644 --- a/src/replicate/_utils/_typing.py +++ b/src/replicate/_utils/_typing.py @@ -110,7 +110,7 @@ class MyResponse(Foo[_T]): ``` """ cls = cast(object, get_origin(typ) or typ) - if cls in generic_bases: + if cls in generic_bases: # pyright: ignore[reportUnnecessaryContains] # we're given the class directly return extract_type_arg(typ, index) diff --git a/src/replicate/_utils/_utils.py b/src/replicate/_utils/_utils.py index e5811bb..ea3cf3f 100644 --- a/src/replicate/_utils/_utils.py +++ b/src/replicate/_utils/_utils.py @@ -72,8 +72,16 @@ def _extract_items( from .._files import assert_is_file_content # We have exhausted the path, return the entry we found. - assert_is_file_content(obj, key=flattened_key) assert flattened_key is not None + + if is_list(obj): + files: list[tuple[str, FileTypes]] = [] + for entry in obj: + assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "") + files.append((flattened_key + "[]", cast(FileTypes, entry))) + return files + + assert_is_file_content(obj, key=flattened_key) return [(flattened_key, cast(FileTypes, obj))] index += 1 diff --git a/src/replicate/_version.py b/src/replicate/_version.py index dfeb99c..5b5e6c4 100644 --- a/src/replicate/_version.py +++ b/src/replicate/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "replicate" -__version__ = "0.1.0-alpha.2" # x-release-please-version +__version__ = "0.1.0-alpha.3" # x-release-please-version diff --git a/src/replicate/resources/deployments/deployments.py b/src/replicate/resources/deployments/deployments.py index 0211b67..565a1c6 100644 --- a/src/replicate/resources/deployments/deployments.py +++ b/src/replicate/resources/deployments/deployments.py @@ -6,10 +6,7 @@ from ...types import deployment_create_params, deployment_update_params from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven -from ..._utils import ( - maybe_transform, - async_maybe_transform, -) +from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -28,10 +25,10 @@ ) from ...pagination import SyncCursorURLPage, AsyncCursorURLPage from ..._base_client import AsyncPaginator, make_request_options +from ...types.deployment_get_response import DeploymentGetResponse from ...types.deployment_list_response import DeploymentListResponse from ...types.deployment_create_response import DeploymentCreateResponse from ...types.deployment_update_response import DeploymentUpdateResponse -from ...types.deployment_retrieve_response import DeploymentRetrieveResponse __all__ = ["DeploymentsResource", "AsyncDeploymentsResource"] @@ -165,77 +162,6 @@ def create( cast_to=DeploymentCreateResponse, ) - def retrieve( - self, - deployment_name: str, - *, - deployment_owner: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> DeploymentRetrieveResponse: - """ - Get information about a deployment by name including the current release. - - Example cURL request: - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/deployments/replicate/my-app-image-generator - ``` - - The response will be a JSON object describing the deployment: - - ```json - { - "owner": "acme", - "name": "my-app-image-generator", - "current_release": { - "number": 1, - "model": "stability-ai/sdxl", - "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", - "created_at": "2024-02-15T16:32:57.018467Z", - "created_by": { - "type": "organization", - "username": "acme", - "name": "Acme Corp, Inc.", - "avatar_url": "https://cdn.replicate.com/avatars/acme.png", - "github_url": "https://github.com/acme" - }, - "configuration": { - "hardware": "gpu-t4", - "min_instances": 1, - "max_instances": 5 - } - } - } - ``` - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not deployment_owner: - raise ValueError(f"Expected a non-empty value for `deployment_owner` but received {deployment_owner!r}") - if not deployment_name: - raise ValueError(f"Expected a non-empty value for `deployment_name` but received {deployment_name!r}") - return self._get( - f"/deployments/{deployment_owner}/{deployment_name}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=DeploymentRetrieveResponse, - ) - def update( self, deployment_name: str, @@ -454,6 +380,77 @@ def delete( cast_to=NoneType, ) + def get( + self, + deployment_name: str, + *, + deployment_owner: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> DeploymentGetResponse: + """ + Get information about a deployment by name including the current release. + + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/deployments/replicate/my-app-image-generator + ``` + + The response will be a JSON object describing the deployment: + + ```json + { + "owner": "acme", + "name": "my-app-image-generator", + "current_release": { + "number": 1, + "model": "stability-ai/sdxl", + "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", + "created_at": "2024-02-15T16:32:57.018467Z", + "created_by": { + "type": "organization", + "username": "acme", + "name": "Acme Corp, Inc.", + "avatar_url": "https://cdn.replicate.com/avatars/acme.png", + "github_url": "https://github.com/acme" + }, + "configuration": { + "hardware": "gpu-t4", + "min_instances": 1, + "max_instances": 5 + } + } + } + ``` + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not deployment_owner: + raise ValueError(f"Expected a non-empty value for `deployment_owner` but received {deployment_owner!r}") + if not deployment_name: + raise ValueError(f"Expected a non-empty value for `deployment_name` but received {deployment_name!r}") + return self._get( + f"/deployments/{deployment_owner}/{deployment_name}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=DeploymentGetResponse, + ) + def list_em_all( self, *, @@ -628,77 +625,6 @@ async def create( cast_to=DeploymentCreateResponse, ) - async def retrieve( - self, - deployment_name: str, - *, - deployment_owner: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> DeploymentRetrieveResponse: - """ - Get information about a deployment by name including the current release. - - Example cURL request: - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/deployments/replicate/my-app-image-generator - ``` - - The response will be a JSON object describing the deployment: - - ```json - { - "owner": "acme", - "name": "my-app-image-generator", - "current_release": { - "number": 1, - "model": "stability-ai/sdxl", - "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", - "created_at": "2024-02-15T16:32:57.018467Z", - "created_by": { - "type": "organization", - "username": "acme", - "name": "Acme Corp, Inc.", - "avatar_url": "https://cdn.replicate.com/avatars/acme.png", - "github_url": "https://github.com/acme" - }, - "configuration": { - "hardware": "gpu-t4", - "min_instances": 1, - "max_instances": 5 - } - } - } - ``` - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not deployment_owner: - raise ValueError(f"Expected a non-empty value for `deployment_owner` but received {deployment_owner!r}") - if not deployment_name: - raise ValueError(f"Expected a non-empty value for `deployment_name` but received {deployment_name!r}") - return await self._get( - f"/deployments/{deployment_owner}/{deployment_name}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=DeploymentRetrieveResponse, - ) - async def update( self, deployment_name: str, @@ -917,6 +843,77 @@ async def delete( cast_to=NoneType, ) + async def get( + self, + deployment_name: str, + *, + deployment_owner: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> DeploymentGetResponse: + """ + Get information about a deployment by name including the current release. + + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/deployments/replicate/my-app-image-generator + ``` + + The response will be a JSON object describing the deployment: + + ```json + { + "owner": "acme", + "name": "my-app-image-generator", + "current_release": { + "number": 1, + "model": "stability-ai/sdxl", + "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", + "created_at": "2024-02-15T16:32:57.018467Z", + "created_by": { + "type": "organization", + "username": "acme", + "name": "Acme Corp, Inc.", + "avatar_url": "https://cdn.replicate.com/avatars/acme.png", + "github_url": "https://github.com/acme" + }, + "configuration": { + "hardware": "gpu-t4", + "min_instances": 1, + "max_instances": 5 + } + } + } + ``` + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not deployment_owner: + raise ValueError(f"Expected a non-empty value for `deployment_owner` but received {deployment_owner!r}") + if not deployment_name: + raise ValueError(f"Expected a non-empty value for `deployment_name` but received {deployment_name!r}") + return await self._get( + f"/deployments/{deployment_owner}/{deployment_name}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=DeploymentGetResponse, + ) + async def list_em_all( self, *, @@ -969,9 +966,6 @@ def __init__(self, deployments: DeploymentsResource) -> None: self.create = to_raw_response_wrapper( deployments.create, ) - self.retrieve = to_raw_response_wrapper( - deployments.retrieve, - ) self.update = to_raw_response_wrapper( deployments.update, ) @@ -981,6 +975,9 @@ def __init__(self, deployments: DeploymentsResource) -> None: self.delete = to_raw_response_wrapper( deployments.delete, ) + self.get = to_raw_response_wrapper( + deployments.get, + ) self.list_em_all = to_raw_response_wrapper( deployments.list_em_all, ) @@ -997,9 +994,6 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None: self.create = async_to_raw_response_wrapper( deployments.create, ) - self.retrieve = async_to_raw_response_wrapper( - deployments.retrieve, - ) self.update = async_to_raw_response_wrapper( deployments.update, ) @@ -1009,6 +1003,9 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None: self.delete = async_to_raw_response_wrapper( deployments.delete, ) + self.get = async_to_raw_response_wrapper( + deployments.get, + ) self.list_em_all = async_to_raw_response_wrapper( deployments.list_em_all, ) @@ -1025,9 +1022,6 @@ def __init__(self, deployments: DeploymentsResource) -> None: self.create = to_streamed_response_wrapper( deployments.create, ) - self.retrieve = to_streamed_response_wrapper( - deployments.retrieve, - ) self.update = to_streamed_response_wrapper( deployments.update, ) @@ -1037,6 +1031,9 @@ def __init__(self, deployments: DeploymentsResource) -> None: self.delete = to_streamed_response_wrapper( deployments.delete, ) + self.get = to_streamed_response_wrapper( + deployments.get, + ) self.list_em_all = to_streamed_response_wrapper( deployments.list_em_all, ) @@ -1053,9 +1050,6 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None: self.create = async_to_streamed_response_wrapper( deployments.create, ) - self.retrieve = async_to_streamed_response_wrapper( - deployments.retrieve, - ) self.update = async_to_streamed_response_wrapper( deployments.update, ) @@ -1065,6 +1059,9 @@ def __init__(self, deployments: AsyncDeploymentsResource) -> None: self.delete = async_to_streamed_response_wrapper( deployments.delete, ) + self.get = async_to_streamed_response_wrapper( + deployments.get, + ) self.list_em_all = async_to_streamed_response_wrapper( deployments.list_em_all, ) diff --git a/src/replicate/resources/deployments/predictions.py b/src/replicate/resources/deployments/predictions.py index 7177332..c94fff0 100644 --- a/src/replicate/resources/deployments/predictions.py +++ b/src/replicate/resources/deployments/predictions.py @@ -8,11 +8,7 @@ import httpx from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ..._utils import ( - maybe_transform, - strip_not_given, - async_maybe_transform, -) +from ..._utils import maybe_transform, strip_not_given, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( diff --git a/src/replicate/resources/models/models.py b/src/replicate/resources/models/models.py index 77145ae..435101f 100644 --- a/src/replicate/resources/models/models.py +++ b/src/replicate/resources/models/models.py @@ -9,11 +9,7 @@ from ...types import model_create_params, model_create_prediction_params from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven -from ..._utils import ( - maybe_transform, - strip_not_given, - async_maybe_transform, -) +from ..._utils import maybe_transform, strip_not_given, async_maybe_transform from .versions import ( VersionsResource, AsyncVersionsResource, @@ -174,115 +170,6 @@ def create( cast_to=NoneType, ) - def retrieve( - self, - model_name: str, - *, - model_owner: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: - """ - Example cURL request: - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/models/replicate/hello-world - ``` - - The response will be a model object in the following format: - - ```json - { - "url": "https://replicate.com/replicate/hello-world", - "owner": "replicate", - "name": "hello-world", - "description": "A tiny model that says hello", - "visibility": "public", - "github_url": "https://github.com/replicate/cog-examples", - "paper_url": null, - "license_url": null, - "run_count": 5681081, - "cover_image_url": "...", - "default_example": {...}, - "latest_version": {...}, - } - ``` - - The model object includes the - [input and output schema](https://replicate.com/docs/reference/openapi#model-schemas) - for the latest version of the model. - - Here's an example showing how to fetch the model with cURL and display its input - schema with [jq](https://stedolan.github.io/jq/): - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/models/replicate/hello-world \\ - | jq ".latest_version.openapi_schema.components.schemas.Input" - ``` - - This will return the following JSON object: - - ```json - { - "type": "object", - "title": "Input", - "required": ["text"], - "properties": { - "text": { - "type": "string", - "title": "Text", - "x-order": 0, - "description": "Text to prefix with 'hello '" - } - } - } - ``` - - The `cover_image_url` string is an HTTPS URL for an image file. This can be: - - - An image uploaded by the model author. - - The output file of the example prediction, if the model author has not set a - cover image. - - The input file of the example prediction, if the model author has not set a - cover image and the example prediction has no output file. - - A generic fallback image. - - The `default_example` object is a [prediction](#predictions.get) created with - this model. - - The `latest_version` object is the model's most recently pushed - [version](#models.versions.get). - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not model_owner: - raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") - if not model_name: - raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return self._get( - f"/models/{model_owner}/{model_name}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=NoneType, - ) - def list( self, *, @@ -514,6 +401,115 @@ def create_prediction( cast_to=Prediction, ) + def get( + self, + model_name: str, + *, + model_owner: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> None: + """ + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/models/replicate/hello-world + ``` + + The response will be a model object in the following format: + + ```json + { + "url": "https://replicate.com/replicate/hello-world", + "owner": "replicate", + "name": "hello-world", + "description": "A tiny model that says hello", + "visibility": "public", + "github_url": "https://github.com/replicate/cog-examples", + "paper_url": null, + "license_url": null, + "run_count": 5681081, + "cover_image_url": "...", + "default_example": {...}, + "latest_version": {...}, + } + ``` + + The model object includes the + [input and output schema](https://replicate.com/docs/reference/openapi#model-schemas) + for the latest version of the model. + + Here's an example showing how to fetch the model with cURL and display its input + schema with [jq](https://stedolan.github.io/jq/): + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/models/replicate/hello-world \\ + | jq ".latest_version.openapi_schema.components.schemas.Input" + ``` + + This will return the following JSON object: + + ```json + { + "type": "object", + "title": "Input", + "required": ["text"], + "properties": { + "text": { + "type": "string", + "title": "Text", + "x-order": 0, + "description": "Text to prefix with 'hello '" + } + } + } + ``` + + The `cover_image_url` string is an HTTPS URL for an image file. This can be: + + - An image uploaded by the model author. + - The output file of the example prediction, if the model author has not set a + cover image. + - The input file of the example prediction, if the model author has not set a + cover image and the example prediction has no output file. + - A generic fallback image. + + The `default_example` object is a [prediction](#predictions.get) created with + this model. + + The `latest_version` object is the model's most recently pushed + [version](#models.versions.get). + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not model_owner: + raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") + if not model_name: + raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") + extra_headers = {"Accept": "*/*", **(extra_headers or {})} + return self._get( + f"/models/{model_owner}/{model_name}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=NoneType, + ) + class AsyncModelsResource(AsyncAPIResource): @cached_property @@ -651,115 +647,6 @@ async def create( cast_to=NoneType, ) - async def retrieve( - self, - model_name: str, - *, - model_owner: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: - """ - Example cURL request: - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/models/replicate/hello-world - ``` - - The response will be a model object in the following format: - - ```json - { - "url": "https://replicate.com/replicate/hello-world", - "owner": "replicate", - "name": "hello-world", - "description": "A tiny model that says hello", - "visibility": "public", - "github_url": "https://github.com/replicate/cog-examples", - "paper_url": null, - "license_url": null, - "run_count": 5681081, - "cover_image_url": "...", - "default_example": {...}, - "latest_version": {...}, - } - ``` - - The model object includes the - [input and output schema](https://replicate.com/docs/reference/openapi#model-schemas) - for the latest version of the model. - - Here's an example showing how to fetch the model with cURL and display its input - schema with [jq](https://stedolan.github.io/jq/): - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/models/replicate/hello-world \\ - | jq ".latest_version.openapi_schema.components.schemas.Input" - ``` - - This will return the following JSON object: - - ```json - { - "type": "object", - "title": "Input", - "required": ["text"], - "properties": { - "text": { - "type": "string", - "title": "Text", - "x-order": 0, - "description": "Text to prefix with 'hello '" - } - } - } - ``` - - The `cover_image_url` string is an HTTPS URL for an image file. This can be: - - - An image uploaded by the model author. - - The output file of the example prediction, if the model author has not set a - cover image. - - The input file of the example prediction, if the model author has not set a - cover image and the example prediction has no output file. - - A generic fallback image. - - The `default_example` object is a [prediction](#predictions.get) created with - this model. - - The `latest_version` object is the model's most recently pushed - [version](#models.versions.get). - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not model_owner: - raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") - if not model_name: - raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return await self._get( - f"/models/{model_owner}/{model_name}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=NoneType, - ) - def list( self, *, @@ -991,6 +878,115 @@ async def create_prediction( cast_to=Prediction, ) + async def get( + self, + model_name: str, + *, + model_owner: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> None: + """ + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/models/replicate/hello-world + ``` + + The response will be a model object in the following format: + + ```json + { + "url": "https://replicate.com/replicate/hello-world", + "owner": "replicate", + "name": "hello-world", + "description": "A tiny model that says hello", + "visibility": "public", + "github_url": "https://github.com/replicate/cog-examples", + "paper_url": null, + "license_url": null, + "run_count": 5681081, + "cover_image_url": "...", + "default_example": {...}, + "latest_version": {...}, + } + ``` + + The model object includes the + [input and output schema](https://replicate.com/docs/reference/openapi#model-schemas) + for the latest version of the model. + + Here's an example showing how to fetch the model with cURL and display its input + schema with [jq](https://stedolan.github.io/jq/): + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/models/replicate/hello-world \\ + | jq ".latest_version.openapi_schema.components.schemas.Input" + ``` + + This will return the following JSON object: + + ```json + { + "type": "object", + "title": "Input", + "required": ["text"], + "properties": { + "text": { + "type": "string", + "title": "Text", + "x-order": 0, + "description": "Text to prefix with 'hello '" + } + } + } + ``` + + The `cover_image_url` string is an HTTPS URL for an image file. This can be: + + - An image uploaded by the model author. + - The output file of the example prediction, if the model author has not set a + cover image. + - The input file of the example prediction, if the model author has not set a + cover image and the example prediction has no output file. + - A generic fallback image. + + The `default_example` object is a [prediction](#predictions.get) created with + this model. + + The `latest_version` object is the model's most recently pushed + [version](#models.versions.get). + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not model_owner: + raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") + if not model_name: + raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") + extra_headers = {"Accept": "*/*", **(extra_headers or {})} + return await self._get( + f"/models/{model_owner}/{model_name}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=NoneType, + ) + class ModelsResourceWithRawResponse: def __init__(self, models: ModelsResource) -> None: @@ -999,9 +995,6 @@ def __init__(self, models: ModelsResource) -> None: self.create = to_raw_response_wrapper( models.create, ) - self.retrieve = to_raw_response_wrapper( - models.retrieve, - ) self.list = to_raw_response_wrapper( models.list, ) @@ -1011,6 +1004,9 @@ def __init__(self, models: ModelsResource) -> None: self.create_prediction = to_raw_response_wrapper( models.create_prediction, ) + self.get = to_raw_response_wrapper( + models.get, + ) @cached_property def versions(self) -> VersionsResourceWithRawResponse: @@ -1024,9 +1020,6 @@ def __init__(self, models: AsyncModelsResource) -> None: self.create = async_to_raw_response_wrapper( models.create, ) - self.retrieve = async_to_raw_response_wrapper( - models.retrieve, - ) self.list = async_to_raw_response_wrapper( models.list, ) @@ -1036,6 +1029,9 @@ def __init__(self, models: AsyncModelsResource) -> None: self.create_prediction = async_to_raw_response_wrapper( models.create_prediction, ) + self.get = async_to_raw_response_wrapper( + models.get, + ) @cached_property def versions(self) -> AsyncVersionsResourceWithRawResponse: @@ -1049,9 +1045,6 @@ def __init__(self, models: ModelsResource) -> None: self.create = to_streamed_response_wrapper( models.create, ) - self.retrieve = to_streamed_response_wrapper( - models.retrieve, - ) self.list = to_streamed_response_wrapper( models.list, ) @@ -1061,6 +1054,9 @@ def __init__(self, models: ModelsResource) -> None: self.create_prediction = to_streamed_response_wrapper( models.create_prediction, ) + self.get = to_streamed_response_wrapper( + models.get, + ) @cached_property def versions(self) -> VersionsResourceWithStreamingResponse: @@ -1074,9 +1070,6 @@ def __init__(self, models: AsyncModelsResource) -> None: self.create = async_to_streamed_response_wrapper( models.create, ) - self.retrieve = async_to_streamed_response_wrapper( - models.retrieve, - ) self.list = async_to_streamed_response_wrapper( models.list, ) @@ -1086,6 +1079,9 @@ def __init__(self, models: AsyncModelsResource) -> None: self.create_prediction = async_to_streamed_response_wrapper( models.create_prediction, ) + self.get = async_to_streamed_response_wrapper( + models.get, + ) @cached_property def versions(self) -> AsyncVersionsResourceWithStreamingResponse: diff --git a/src/replicate/resources/models/versions.py b/src/replicate/resources/models/versions.py index eb411ef..308d04b 100644 --- a/src/replicate/resources/models/versions.py +++ b/src/replicate/resources/models/versions.py @@ -8,10 +8,7 @@ import httpx from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven -from ..._utils import ( - maybe_transform, - async_maybe_transform, -) +from ..._utils import maybe_transform, async_maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -22,6 +19,7 @@ ) from ..._base_client import make_request_options from ...types.models import version_create_training_params +from ...types.models.version_create_training_response import VersionCreateTrainingResponse __all__ = ["VersionsResource", "AsyncVersionsResource"] @@ -46,101 +44,6 @@ def with_streaming_response(self) -> VersionsResourceWithStreamingResponse: """ return VersionsResourceWithStreamingResponse(self) - def retrieve( - self, - version_id: str, - *, - model_owner: str, - model_name: str, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: - """ - Example cURL request: - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/models/replicate/hello-world/versions/5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa - ``` - - The response will be the version object: - - ```json - { - "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - "created_at": "2022-04-26T19:29:04.418669Z", - "cog_version": "0.3.0", - "openapi_schema": {...} - } - ``` - - Every model describes its inputs and outputs with - [OpenAPI Schema Objects](https://spec.openapis.org/oas/latest.html#schemaObject) - in the `openapi_schema` property. - - The `openapi_schema.components.schemas.Input` property for the - [replicate/hello-world](https://replicate.com/replicate/hello-world) model looks - like this: - - ```json - { - "type": "object", - "title": "Input", - "required": ["text"], - "properties": { - "text": { - "x-order": 0, - "type": "string", - "title": "Text", - "description": "Text to prefix with 'hello '" - } - } - } - ``` - - The `openapi_schema.components.schemas.Output` property for the - [replicate/hello-world](https://replicate.com/replicate/hello-world) model looks - like this: - - ```json - { - "type": "string", - "title": "Output" - } - ``` - - For more details, see the docs on - [Cog's supported input and output types](https://github.com/replicate/cog/blob/75b7802219e7cd4cee845e34c4c22139558615d4/docs/python.md#input-and-output-types) - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not model_owner: - raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") - if not model_name: - raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") - if not version_id: - raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return self._get( - f"/models/{model_owner}/{model_name}/versions/{version_id}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=NoneType, - ) - def list( self, model_name: str, @@ -281,7 +184,7 @@ def create_training( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> VersionCreateTrainingResponse: """ Start a new training of the model version you specify. @@ -398,7 +301,6 @@ def create_training( raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") if not version_id: raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} return self._post( f"/models/{model_owner}/{model_name}/versions/{version_id}/trainings", body=maybe_transform( @@ -413,31 +315,10 @@ def create_training( options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + cast_to=VersionCreateTrainingResponse, ) - -class AsyncVersionsResource(AsyncAPIResource): - @cached_property - def with_raw_response(self) -> AsyncVersionsResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers - """ - return AsyncVersionsResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> AsyncVersionsResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. - - For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response - """ - return AsyncVersionsResourceWithStreamingResponse(self) - - async def retrieve( + def get( self, version_id: str, *, @@ -524,7 +405,7 @@ async def retrieve( if not version_id: raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}") extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return await self._get( + return self._get( f"/models/{model_owner}/{model_name}/versions/{version_id}", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout @@ -532,6 +413,27 @@ async def retrieve( cast_to=NoneType, ) + +class AsyncVersionsResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncVersionsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers + """ + return AsyncVersionsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncVersionsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response + """ + return AsyncVersionsResourceWithStreamingResponse(self) + async def list( self, model_name: str, @@ -672,7 +574,7 @@ async def create_training( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> VersionCreateTrainingResponse: """ Start a new training of the model version you specify. @@ -789,7 +691,6 @@ async def create_training( raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") if not version_id: raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} return await self._post( f"/models/{model_owner}/{model_name}/versions/{version_id}/trainings", body=await async_maybe_transform( @@ -804,6 +705,101 @@ async def create_training( options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), + cast_to=VersionCreateTrainingResponse, + ) + + async def get( + self, + version_id: str, + *, + model_owner: str, + model_name: str, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> None: + """ + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/models/replicate/hello-world/versions/5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa + ``` + + The response will be the version object: + + ```json + { + "id": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + "created_at": "2022-04-26T19:29:04.418669Z", + "cog_version": "0.3.0", + "openapi_schema": {...} + } + ``` + + Every model describes its inputs and outputs with + [OpenAPI Schema Objects](https://spec.openapis.org/oas/latest.html#schemaObject) + in the `openapi_schema` property. + + The `openapi_schema.components.schemas.Input` property for the + [replicate/hello-world](https://replicate.com/replicate/hello-world) model looks + like this: + + ```json + { + "type": "object", + "title": "Input", + "required": ["text"], + "properties": { + "text": { + "x-order": 0, + "type": "string", + "title": "Text", + "description": "Text to prefix with 'hello '" + } + } + } + ``` + + The `openapi_schema.components.schemas.Output` property for the + [replicate/hello-world](https://replicate.com/replicate/hello-world) model looks + like this: + + ```json + { + "type": "string", + "title": "Output" + } + ``` + + For more details, see the docs on + [Cog's supported input and output types](https://github.com/replicate/cog/blob/75b7802219e7cd4cee845e34c4c22139558615d4/docs/python.md#input-and-output-types) + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not model_owner: + raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}") + if not model_name: + raise ValueError(f"Expected a non-empty value for `model_name` but received {model_name!r}") + if not version_id: + raise ValueError(f"Expected a non-empty value for `version_id` but received {version_id!r}") + extra_headers = {"Accept": "*/*", **(extra_headers or {})} + return await self._get( + f"/models/{model_owner}/{model_name}/versions/{version_id}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), cast_to=NoneType, ) @@ -812,9 +808,6 @@ class VersionsResourceWithRawResponse: def __init__(self, versions: VersionsResource) -> None: self._versions = versions - self.retrieve = to_raw_response_wrapper( - versions.retrieve, - ) self.list = to_raw_response_wrapper( versions.list, ) @@ -824,15 +817,15 @@ def __init__(self, versions: VersionsResource) -> None: self.create_training = to_raw_response_wrapper( versions.create_training, ) + self.get = to_raw_response_wrapper( + versions.get, + ) class AsyncVersionsResourceWithRawResponse: def __init__(self, versions: AsyncVersionsResource) -> None: self._versions = versions - self.retrieve = async_to_raw_response_wrapper( - versions.retrieve, - ) self.list = async_to_raw_response_wrapper( versions.list, ) @@ -842,15 +835,15 @@ def __init__(self, versions: AsyncVersionsResource) -> None: self.create_training = async_to_raw_response_wrapper( versions.create_training, ) + self.get = async_to_raw_response_wrapper( + versions.get, + ) class VersionsResourceWithStreamingResponse: def __init__(self, versions: VersionsResource) -> None: self._versions = versions - self.retrieve = to_streamed_response_wrapper( - versions.retrieve, - ) self.list = to_streamed_response_wrapper( versions.list, ) @@ -860,15 +853,15 @@ def __init__(self, versions: VersionsResource) -> None: self.create_training = to_streamed_response_wrapper( versions.create_training, ) + self.get = to_streamed_response_wrapper( + versions.get, + ) class AsyncVersionsResourceWithStreamingResponse: def __init__(self, versions: AsyncVersionsResource) -> None: self._versions = versions - self.retrieve = async_to_streamed_response_wrapper( - versions.retrieve, - ) self.list = async_to_streamed_response_wrapper( versions.list, ) @@ -878,3 +871,6 @@ def __init__(self, versions: AsyncVersionsResource) -> None: self.create_training = async_to_streamed_response_wrapper( versions.create_training, ) + self.get = async_to_streamed_response_wrapper( + versions.get, + ) diff --git a/src/replicate/resources/predictions.py b/src/replicate/resources/predictions.py index d6a1245..abe0ca1 100644 --- a/src/replicate/resources/predictions.py +++ b/src/replicate/resources/predictions.py @@ -10,11 +10,7 @@ from ..types import prediction_list_params, prediction_create_params from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven -from .._utils import ( - maybe_transform, - strip_not_given, - async_maybe_transform, -) +from .._utils import maybe_transform, strip_not_given, async_maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource from .._response import ( @@ -189,108 +185,6 @@ def create( cast_to=Prediction, ) - def retrieve( - self, - prediction_id: str, - *, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Prediction: - """ - Get the current state of a prediction. - - Example cURL request: - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu - ``` - - The response will be the prediction object: - - ```json - { - "id": "gm3qorzdhgbfurvjtvhg6dckhu", - "model": "replicate/hello-world", - "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - "input": { - "text": "Alice" - }, - "logs": "", - "output": "hello Alice", - "error": null, - "status": "succeeded", - "created_at": "2023-09-08T16:19:34.765994Z", - "data_removed": false, - "started_at": "2023-09-08T16:19:34.779176Z", - "completed_at": "2023-09-08T16:19:34.791859Z", - "metrics": { - "predict_time": 0.012683 - }, - "urls": { - "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel", - "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu" - } - } - ``` - - `status` will be one of: - - - `starting`: the prediction is starting up. If this status lasts longer than a - few seconds, then it's typically because a new worker is being started to run - the prediction. - - `processing`: the `predict()` method of the model is currently running. - - `succeeded`: the prediction completed successfully. - - `failed`: the prediction encountered an error during processing. - - `canceled`: the prediction was canceled by its creator. - - In the case of success, `output` will be an object containing the output of the - model. Any files will be represented as HTTPS URLs. You'll need to pass the - `Authorization` header to request them. - - In the case of failure, `error` will contain the error encountered during the - prediction. - - Terminated predictions (with a status of `succeeded`, `failed`, or `canceled`) - will include a `metrics` object with a `predict_time` property showing the - amount of CPU or GPU time, in seconds, that the prediction used while running. - It won't include time waiting for the prediction to start. - - All input parameters, output values, and logs are automatically removed after an - hour, by default, for predictions created through the API. - - You must save a copy of any data or files in the output if you'd like to - continue using them. The `output` key will still be present, but it's value will - be `null` after the output has been removed. - - Output files are served by `replicate.delivery` and its subdomains. If you use - an allow list of external domains for your assets, add `replicate.delivery` and - `*.replicate.delivery` to it. - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not prediction_id: - raise ValueError(f"Expected a non-empty value for `prediction_id` but received {prediction_id!r}") - return self._get( - f"/predictions/{prediction_id}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Prediction, - ) - def list( self, *, @@ -440,6 +334,108 @@ def cancel( cast_to=NoneType, ) + def get( + self, + prediction_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Prediction: + """ + Get the current state of a prediction. + + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu + ``` + + The response will be the prediction object: + + ```json + { + "id": "gm3qorzdhgbfurvjtvhg6dckhu", + "model": "replicate/hello-world", + "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + "input": { + "text": "Alice" + }, + "logs": "", + "output": "hello Alice", + "error": null, + "status": "succeeded", + "created_at": "2023-09-08T16:19:34.765994Z", + "data_removed": false, + "started_at": "2023-09-08T16:19:34.779176Z", + "completed_at": "2023-09-08T16:19:34.791859Z", + "metrics": { + "predict_time": 0.012683 + }, + "urls": { + "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel", + "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu" + } + } + ``` + + `status` will be one of: + + - `starting`: the prediction is starting up. If this status lasts longer than a + few seconds, then it's typically because a new worker is being started to run + the prediction. + - `processing`: the `predict()` method of the model is currently running. + - `succeeded`: the prediction completed successfully. + - `failed`: the prediction encountered an error during processing. + - `canceled`: the prediction was canceled by its creator. + + In the case of success, `output` will be an object containing the output of the + model. Any files will be represented as HTTPS URLs. You'll need to pass the + `Authorization` header to request them. + + In the case of failure, `error` will contain the error encountered during the + prediction. + + Terminated predictions (with a status of `succeeded`, `failed`, or `canceled`) + will include a `metrics` object with a `predict_time` property showing the + amount of CPU or GPU time, in seconds, that the prediction used while running. + It won't include time waiting for the prediction to start. + + All input parameters, output values, and logs are automatically removed after an + hour, by default, for predictions created through the API. + + You must save a copy of any data or files in the output if you'd like to + continue using them. The `output` key will still be present, but it's value will + be `null` after the output has been removed. + + Output files are served by `replicate.delivery` and its subdomains. If you use + an allow list of external domains for your assets, add `replicate.delivery` and + `*.replicate.delivery` to it. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not prediction_id: + raise ValueError(f"Expected a non-empty value for `prediction_id` but received {prediction_id!r}") + return self._get( + f"/predictions/{prediction_id}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Prediction, + ) + class AsyncPredictionsResource(AsyncAPIResource): @cached_property @@ -600,108 +596,6 @@ async def create( cast_to=Prediction, ) - async def retrieve( - self, - prediction_id: str, - *, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> Prediction: - """ - Get the current state of a prediction. - - Example cURL request: - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu - ``` - - The response will be the prediction object: - - ```json - { - "id": "gm3qorzdhgbfurvjtvhg6dckhu", - "model": "replicate/hello-world", - "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", - "input": { - "text": "Alice" - }, - "logs": "", - "output": "hello Alice", - "error": null, - "status": "succeeded", - "created_at": "2023-09-08T16:19:34.765994Z", - "data_removed": false, - "started_at": "2023-09-08T16:19:34.779176Z", - "completed_at": "2023-09-08T16:19:34.791859Z", - "metrics": { - "predict_time": 0.012683 - }, - "urls": { - "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel", - "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu" - } - } - ``` - - `status` will be one of: - - - `starting`: the prediction is starting up. If this status lasts longer than a - few seconds, then it's typically because a new worker is being started to run - the prediction. - - `processing`: the `predict()` method of the model is currently running. - - `succeeded`: the prediction completed successfully. - - `failed`: the prediction encountered an error during processing. - - `canceled`: the prediction was canceled by its creator. - - In the case of success, `output` will be an object containing the output of the - model. Any files will be represented as HTTPS URLs. You'll need to pass the - `Authorization` header to request them. - - In the case of failure, `error` will contain the error encountered during the - prediction. - - Terminated predictions (with a status of `succeeded`, `failed`, or `canceled`) - will include a `metrics` object with a `predict_time` property showing the - amount of CPU or GPU time, in seconds, that the prediction used while running. - It won't include time waiting for the prediction to start. - - All input parameters, output values, and logs are automatically removed after an - hour, by default, for predictions created through the API. - - You must save a copy of any data or files in the output if you'd like to - continue using them. The `output` key will still be present, but it's value will - be `null` after the output has been removed. - - Output files are served by `replicate.delivery` and its subdomains. If you use - an allow list of external domains for your assets, add `replicate.delivery` and - `*.replicate.delivery` to it. - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not prediction_id: - raise ValueError(f"Expected a non-empty value for `prediction_id` but received {prediction_id!r}") - return await self._get( - f"/predictions/{prediction_id}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=Prediction, - ) - def list( self, *, @@ -851,6 +745,108 @@ async def cancel( cast_to=NoneType, ) + async def get( + self, + prediction_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Prediction: + """ + Get the current state of a prediction. + + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu + ``` + + The response will be the prediction object: + + ```json + { + "id": "gm3qorzdhgbfurvjtvhg6dckhu", + "model": "replicate/hello-world", + "version": "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa", + "input": { + "text": "Alice" + }, + "logs": "", + "output": "hello Alice", + "error": null, + "status": "succeeded", + "created_at": "2023-09-08T16:19:34.765994Z", + "data_removed": false, + "started_at": "2023-09-08T16:19:34.779176Z", + "completed_at": "2023-09-08T16:19:34.791859Z", + "metrics": { + "predict_time": 0.012683 + }, + "urls": { + "cancel": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu/cancel", + "get": "https://api.replicate.com/v1/predictions/gm3qorzdhgbfurvjtvhg6dckhu" + } + } + ``` + + `status` will be one of: + + - `starting`: the prediction is starting up. If this status lasts longer than a + few seconds, then it's typically because a new worker is being started to run + the prediction. + - `processing`: the `predict()` method of the model is currently running. + - `succeeded`: the prediction completed successfully. + - `failed`: the prediction encountered an error during processing. + - `canceled`: the prediction was canceled by its creator. + + In the case of success, `output` will be an object containing the output of the + model. Any files will be represented as HTTPS URLs. You'll need to pass the + `Authorization` header to request them. + + In the case of failure, `error` will contain the error encountered during the + prediction. + + Terminated predictions (with a status of `succeeded`, `failed`, or `canceled`) + will include a `metrics` object with a `predict_time` property showing the + amount of CPU or GPU time, in seconds, that the prediction used while running. + It won't include time waiting for the prediction to start. + + All input parameters, output values, and logs are automatically removed after an + hour, by default, for predictions created through the API. + + You must save a copy of any data or files in the output if you'd like to + continue using them. The `output` key will still be present, but it's value will + be `null` after the output has been removed. + + Output files are served by `replicate.delivery` and its subdomains. If you use + an allow list of external domains for your assets, add `replicate.delivery` and + `*.replicate.delivery` to it. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not prediction_id: + raise ValueError(f"Expected a non-empty value for `prediction_id` but received {prediction_id!r}") + return await self._get( + f"/predictions/{prediction_id}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Prediction, + ) + class PredictionsResourceWithRawResponse: def __init__(self, predictions: PredictionsResource) -> None: @@ -859,15 +855,15 @@ def __init__(self, predictions: PredictionsResource) -> None: self.create = to_raw_response_wrapper( predictions.create, ) - self.retrieve = to_raw_response_wrapper( - predictions.retrieve, - ) self.list = to_raw_response_wrapper( predictions.list, ) self.cancel = to_raw_response_wrapper( predictions.cancel, ) + self.get = to_raw_response_wrapper( + predictions.get, + ) class AsyncPredictionsResourceWithRawResponse: @@ -877,15 +873,15 @@ def __init__(self, predictions: AsyncPredictionsResource) -> None: self.create = async_to_raw_response_wrapper( predictions.create, ) - self.retrieve = async_to_raw_response_wrapper( - predictions.retrieve, - ) self.list = async_to_raw_response_wrapper( predictions.list, ) self.cancel = async_to_raw_response_wrapper( predictions.cancel, ) + self.get = async_to_raw_response_wrapper( + predictions.get, + ) class PredictionsResourceWithStreamingResponse: @@ -895,15 +891,15 @@ def __init__(self, predictions: PredictionsResource) -> None: self.create = to_streamed_response_wrapper( predictions.create, ) - self.retrieve = to_streamed_response_wrapper( - predictions.retrieve, - ) self.list = to_streamed_response_wrapper( predictions.list, ) self.cancel = to_streamed_response_wrapper( predictions.cancel, ) + self.get = to_streamed_response_wrapper( + predictions.get, + ) class AsyncPredictionsResourceWithStreamingResponse: @@ -913,12 +909,12 @@ def __init__(self, predictions: AsyncPredictionsResource) -> None: self.create = async_to_streamed_response_wrapper( predictions.create, ) - self.retrieve = async_to_streamed_response_wrapper( - predictions.retrieve, - ) self.list = async_to_streamed_response_wrapper( predictions.list, ) self.cancel = async_to_streamed_response_wrapper( predictions.cancel, ) + self.get = async_to_streamed_response_wrapper( + predictions.get, + ) diff --git a/src/replicate/resources/trainings.py b/src/replicate/resources/trainings.py index df64c19..5a357af 100644 --- a/src/replicate/resources/trainings.py +++ b/src/replicate/resources/trainings.py @@ -4,7 +4,7 @@ import httpx -from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven +from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource from .._response import ( @@ -13,7 +13,11 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) -from .._base_client import make_request_options +from ..pagination import SyncCursorURLPage, AsyncCursorURLPage +from .._base_client import AsyncPaginator, make_request_options +from ..types.training_get_response import TrainingGetResponse +from ..types.training_list_response import TrainingListResponse +from ..types.training_cancel_response import TrainingCancelResponse __all__ = ["TrainingsResource", "AsyncTrainingsResource"] @@ -38,100 +42,6 @@ def with_streaming_response(self) -> TrainingsResourceWithStreamingResponse: """ return TrainingsResourceWithStreamingResponse(self) - def retrieve( - self, - training_id: str, - *, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: - """ - Get the current state of a training. - - Example cURL request: - - ```console - curl -s \\ - -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ - https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga - ``` - - The response will be the training object: - - ```json - { - "completed_at": "2023-09-08T16:41:19.826523Z", - "created_at": "2023-09-08T16:32:57.018467Z", - "error": null, - "id": "zz4ibbonubfz7carwiefibzgga", - "input": { - "input_images": "https://example.com/my-input-images.zip" - }, - "logs": "...", - "metrics": { - "predict_time": 502.713876 - }, - "output": { - "version": "...", - "weights": "..." - }, - "started_at": "2023-09-08T16:32:57.112647Z", - "status": "succeeded", - "urls": { - "get": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga", - "cancel": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel" - }, - "model": "stability-ai/sdxl", - "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf" - } - ``` - - `status` will be one of: - - - `starting`: the training is starting up. If this status lasts longer than a - few seconds, then it's typically because a new worker is being started to run - the training. - - `processing`: the `train()` method of the model is currently running. - - `succeeded`: the training completed successfully. - - `failed`: the training encountered an error during processing. - - `canceled`: the training was canceled by its creator. - - In the case of success, `output` will be an object containing the output of the - model. Any files will be represented as HTTPS URLs. You'll need to pass the - `Authorization` header to request them. - - In the case of failure, `error` will contain the error encountered during the - training. - - Terminated trainings (with a status of `succeeded`, `failed`, or `canceled`) - will include a `metrics` object with a `predict_time` property showing the - amount of CPU or GPU time, in seconds, that the training used while running. It - won't include time waiting for the training to start. - - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - if not training_id: - raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return self._get( - f"/trainings/{training_id}", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=NoneType, - ) - def list( self, *, @@ -141,7 +51,7 @@ def list( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> SyncCursorURLPage[TrainingListResponse]: """ Get a paginated list of all trainings created by the user or organization associated with the provided API token. @@ -207,13 +117,13 @@ def list( `version` will be the unique ID of model version used to create the training. """ - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return self._get( + return self._get_api_list( "/trainings", + page=SyncCursorURLPage[TrainingListResponse], options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + model=TrainingListResponse, ) def cancel( @@ -226,7 +136,7 @@ def cancel( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> TrainingCancelResponse: """ Cancel a training @@ -241,37 +151,15 @@ def cancel( """ if not training_id: raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} return self._post( f"/trainings/{training_id}/cancel", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + cast_to=TrainingCancelResponse, ) - -class AsyncTrainingsResource(AsyncAPIResource): - @cached_property - def with_raw_response(self) -> AsyncTrainingsResourceWithRawResponse: - """ - This property can be used as a prefix for any HTTP method call to return - the raw response object instead of the parsed content. - - For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers - """ - return AsyncTrainingsResourceWithRawResponse(self) - - @cached_property - def with_streaming_response(self) -> AsyncTrainingsResourceWithStreamingResponse: - """ - An alternative to `.with_raw_response` that doesn't eagerly read the response body. - - For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response - """ - return AsyncTrainingsResourceWithStreamingResponse(self) - - async def retrieve( + def get( self, training_id: str, *, @@ -281,7 +169,7 @@ async def retrieve( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> TrainingGetResponse: """ Get the current state of a training. @@ -356,16 +244,36 @@ async def retrieve( """ if not training_id: raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return await self._get( + return self._get( f"/trainings/{training_id}", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + cast_to=TrainingGetResponse, ) - async def list( + +class AsyncTrainingsResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncTrainingsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/replicate/replicate-python-stainless#accessing-raw-response-data-eg-headers + """ + return AsyncTrainingsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncTrainingsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/replicate/replicate-python-stainless#with_streaming_response + """ + return AsyncTrainingsResourceWithStreamingResponse(self) + + def list( self, *, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. @@ -374,7 +282,7 @@ async def list( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> AsyncPaginator[TrainingListResponse, AsyncCursorURLPage[TrainingListResponse]]: """ Get a paginated list of all trainings created by the user or organization associated with the provided API token. @@ -440,13 +348,13 @@ async def list( `version` will be the unique ID of model version used to create the training. """ - extra_headers = {"Accept": "*/*", **(extra_headers or {})} - return await self._get( + return self._get_api_list( "/trainings", + page=AsyncCursorURLPage[TrainingListResponse], options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + model=TrainingListResponse, ) async def cancel( @@ -459,7 +367,7 @@ async def cancel( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> TrainingCancelResponse: """ Cancel a training @@ -474,13 +382,105 @@ async def cancel( """ if not training_id: raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}") - extra_headers = {"Accept": "*/*", **(extra_headers or {})} return await self._post( f"/trainings/{training_id}/cancel", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=NoneType, + cast_to=TrainingCancelResponse, + ) + + async def get( + self, + training_id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> TrainingGetResponse: + """ + Get the current state of a training. + + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga + ``` + + The response will be the training object: + + ```json + { + "completed_at": "2023-09-08T16:41:19.826523Z", + "created_at": "2023-09-08T16:32:57.018467Z", + "error": null, + "id": "zz4ibbonubfz7carwiefibzgga", + "input": { + "input_images": "https://example.com/my-input-images.zip" + }, + "logs": "...", + "metrics": { + "predict_time": 502.713876 + }, + "output": { + "version": "...", + "weights": "..." + }, + "started_at": "2023-09-08T16:32:57.112647Z", + "status": "succeeded", + "urls": { + "get": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga", + "cancel": "https://api.replicate.com/v1/trainings/zz4ibbonubfz7carwiefibzgga/cancel" + }, + "model": "stability-ai/sdxl", + "version": "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf" + } + ``` + + `status` will be one of: + + - `starting`: the training is starting up. If this status lasts longer than a + few seconds, then it's typically because a new worker is being started to run + the training. + - `processing`: the `train()` method of the model is currently running. + - `succeeded`: the training completed successfully. + - `failed`: the training encountered an error during processing. + - `canceled`: the training was canceled by its creator. + + In the case of success, `output` will be an object containing the output of the + model. Any files will be represented as HTTPS URLs. You'll need to pass the + `Authorization` header to request them. + + In the case of failure, `error` will contain the error encountered during the + training. + + Terminated trainings (with a status of `succeeded`, `failed`, or `canceled`) + will include a `metrics` object with a `predict_time` property showing the + amount of CPU or GPU time, in seconds, that the training used while running. It + won't include time waiting for the training to start. + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not training_id: + raise ValueError(f"Expected a non-empty value for `training_id` but received {training_id!r}") + return await self._get( + f"/trainings/{training_id}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=TrainingGetResponse, ) @@ -488,57 +488,57 @@ class TrainingsResourceWithRawResponse: def __init__(self, trainings: TrainingsResource) -> None: self._trainings = trainings - self.retrieve = to_raw_response_wrapper( - trainings.retrieve, - ) self.list = to_raw_response_wrapper( trainings.list, ) self.cancel = to_raw_response_wrapper( trainings.cancel, ) + self.get = to_raw_response_wrapper( + trainings.get, + ) class AsyncTrainingsResourceWithRawResponse: def __init__(self, trainings: AsyncTrainingsResource) -> None: self._trainings = trainings - self.retrieve = async_to_raw_response_wrapper( - trainings.retrieve, - ) self.list = async_to_raw_response_wrapper( trainings.list, ) self.cancel = async_to_raw_response_wrapper( trainings.cancel, ) + self.get = async_to_raw_response_wrapper( + trainings.get, + ) class TrainingsResourceWithStreamingResponse: def __init__(self, trainings: TrainingsResource) -> None: self._trainings = trainings - self.retrieve = to_streamed_response_wrapper( - trainings.retrieve, - ) self.list = to_streamed_response_wrapper( trainings.list, ) self.cancel = to_streamed_response_wrapper( trainings.cancel, ) + self.get = to_streamed_response_wrapper( + trainings.get, + ) class AsyncTrainingsResourceWithStreamingResponse: def __init__(self, trainings: AsyncTrainingsResource) -> None: self._trainings = trainings - self.retrieve = async_to_streamed_response_wrapper( - trainings.retrieve, - ) self.list = async_to_streamed_response_wrapper( trainings.list, ) self.cancel = async_to_streamed_response_wrapper( trainings.cancel, ) + self.get = async_to_streamed_response_wrapper( + trainings.get, + ) diff --git a/src/replicate/types/__init__.py b/src/replicate/types/__init__.py index e2b3c58..fa8ee5d 100644 --- a/src/replicate/types/__init__.py +++ b/src/replicate/types/__init__.py @@ -7,13 +7,16 @@ from .model_create_params import ModelCreateParams as ModelCreateParams from .model_list_response import ModelListResponse as ModelListResponse from .account_list_response import AccountListResponse as AccountListResponse +from .training_get_response import TrainingGetResponse as TrainingGetResponse from .hardware_list_response import HardwareListResponse as HardwareListResponse from .prediction_list_params import PredictionListParams as PredictionListParams +from .training_list_response import TrainingListResponse as TrainingListResponse +from .deployment_get_response import DeploymentGetResponse as DeploymentGetResponse from .deployment_create_params import DeploymentCreateParams as DeploymentCreateParams from .deployment_list_response import DeploymentListResponse as DeploymentListResponse from .deployment_update_params import DeploymentUpdateParams as DeploymentUpdateParams from .prediction_create_params import PredictionCreateParams as PredictionCreateParams +from .training_cancel_response import TrainingCancelResponse as TrainingCancelResponse from .deployment_create_response import DeploymentCreateResponse as DeploymentCreateResponse from .deployment_update_response import DeploymentUpdateResponse as DeploymentUpdateResponse -from .deployment_retrieve_response import DeploymentRetrieveResponse as DeploymentRetrieveResponse from .model_create_prediction_params import ModelCreatePredictionParams as ModelCreatePredictionParams diff --git a/src/replicate/types/deployment_retrieve_response.py b/src/replicate/types/deployment_get_response.py similarity index 91% rename from src/replicate/types/deployment_retrieve_response.py rename to src/replicate/types/deployment_get_response.py index 2ecc13c..a7281ff 100644 --- a/src/replicate/types/deployment_retrieve_response.py +++ b/src/replicate/types/deployment_get_response.py @@ -6,7 +6,7 @@ from .._models import BaseModel -__all__ = ["DeploymentRetrieveResponse", "CurrentRelease", "CurrentReleaseConfiguration", "CurrentReleaseCreatedBy"] +__all__ = ["DeploymentGetResponse", "CurrentRelease", "CurrentReleaseConfiguration", "CurrentReleaseCreatedBy"] class CurrentReleaseConfiguration(BaseModel): @@ -55,7 +55,7 @@ class CurrentRelease(BaseModel): """The ID of the model version used in the release.""" -class DeploymentRetrieveResponse(BaseModel): +class DeploymentGetResponse(BaseModel): current_release: Optional[CurrentRelease] = None name: Optional[str] = None diff --git a/src/replicate/types/deployment_list_response.py b/src/replicate/types/deployment_list_response.py index 9f64487..2606d38 100644 --- a/src/replicate/types/deployment_list_response.py +++ b/src/replicate/types/deployment_list_response.py @@ -49,11 +49,7 @@ class CurrentRelease(BaseModel): """The model identifier string in the format of `{model_owner}/{model_name}`.""" number: Optional[int] = None - """The release number. - - This is an auto-incrementing integer that starts at 1, and is set automatically - when a deployment is created. - """ + """The release number.""" version: Optional[str] = None """The ID of the model version used in the release.""" diff --git a/src/replicate/types/models/__init__.py b/src/replicate/types/models/__init__.py index 98e1a3a..2d89bc6 100644 --- a/src/replicate/types/models/__init__.py +++ b/src/replicate/types/models/__init__.py @@ -3,3 +3,4 @@ from __future__ import annotations from .version_create_training_params import VersionCreateTrainingParams as VersionCreateTrainingParams +from .version_create_training_response import VersionCreateTrainingResponse as VersionCreateTrainingResponse diff --git a/src/replicate/types/models/version_create_training_response.py b/src/replicate/types/models/version_create_training_response.py new file mode 100644 index 0000000..5b005fe --- /dev/null +++ b/src/replicate/types/models/version_create_training_response.py @@ -0,0 +1,74 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, Optional +from datetime import datetime +from typing_extensions import Literal + +from ..._models import BaseModel + +__all__ = ["VersionCreateTrainingResponse", "Metrics", "Output", "URLs"] + + +class Metrics(BaseModel): + predict_time: Optional[float] = None + """The amount of CPU or GPU time, in seconds, that the training used while running""" + + +class Output(BaseModel): + version: Optional[str] = None + """The version of the model created by the training""" + + weights: Optional[str] = None + """The weights of the trained model""" + + +class URLs(BaseModel): + cancel: Optional[str] = None + """URL to cancel the training""" + + get: Optional[str] = None + """URL to get the training details""" + + +class VersionCreateTrainingResponse(BaseModel): + id: Optional[str] = None + """The unique ID of the training""" + + completed_at: Optional[datetime] = None + """The time when the training completed""" + + created_at: Optional[datetime] = None + """The time when the training was created""" + + error: Optional[str] = None + """Error message if the training failed""" + + input: Optional[Dict[str, object]] = None + """The input parameters used for the training""" + + logs: Optional[str] = None + """The logs from the training process""" + + metrics: Optional[Metrics] = None + """Metrics about the training process""" + + model: Optional[str] = None + """The name of the model in the format owner/name""" + + output: Optional[Output] = None + """The output of the training process""" + + source: Optional[Literal["web", "api"]] = None + """How the training was created""" + + started_at: Optional[datetime] = None + """The time when the training started""" + + status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None + """The current status of the training""" + + urls: Optional[URLs] = None + """URLs for interacting with the training""" + + version: Optional[str] = None + """The ID of the model version used for training""" diff --git a/src/replicate/types/prediction_output.py b/src/replicate/types/prediction_output.py index 2c46100..8b320b5 100644 --- a/src/replicate/types/prediction_output.py +++ b/src/replicate/types/prediction_output.py @@ -6,5 +6,5 @@ __all__ = ["PredictionOutput"] PredictionOutput: TypeAlias = Union[ - Optional[Dict[str, object]], Optional[List[object]], Optional[str], Optional[float], Optional[bool] + Optional[Dict[str, object]], Optional[List[Dict[str, object]]], Optional[str], Optional[float], Optional[bool] ] diff --git a/src/replicate/types/training_cancel_response.py b/src/replicate/types/training_cancel_response.py new file mode 100644 index 0000000..9af512b --- /dev/null +++ b/src/replicate/types/training_cancel_response.py @@ -0,0 +1,74 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, Optional +from datetime import datetime +from typing_extensions import Literal + +from .._models import BaseModel + +__all__ = ["TrainingCancelResponse", "Metrics", "Output", "URLs"] + + +class Metrics(BaseModel): + predict_time: Optional[float] = None + """The amount of CPU or GPU time, in seconds, that the training used while running""" + + +class Output(BaseModel): + version: Optional[str] = None + """The version of the model created by the training""" + + weights: Optional[str] = None + """The weights of the trained model""" + + +class URLs(BaseModel): + cancel: Optional[str] = None + """URL to cancel the training""" + + get: Optional[str] = None + """URL to get the training details""" + + +class TrainingCancelResponse(BaseModel): + id: Optional[str] = None + """The unique ID of the training""" + + completed_at: Optional[datetime] = None + """The time when the training completed""" + + created_at: Optional[datetime] = None + """The time when the training was created""" + + error: Optional[str] = None + """Error message if the training failed""" + + input: Optional[Dict[str, object]] = None + """The input parameters used for the training""" + + logs: Optional[str] = None + """The logs from the training process""" + + metrics: Optional[Metrics] = None + """Metrics about the training process""" + + model: Optional[str] = None + """The name of the model in the format owner/name""" + + output: Optional[Output] = None + """The output of the training process""" + + source: Optional[Literal["web", "api"]] = None + """How the training was created""" + + started_at: Optional[datetime] = None + """The time when the training started""" + + status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None + """The current status of the training""" + + urls: Optional[URLs] = None + """URLs for interacting with the training""" + + version: Optional[str] = None + """The ID of the model version used for training""" diff --git a/src/replicate/types/training_get_response.py b/src/replicate/types/training_get_response.py new file mode 100644 index 0000000..4169da7 --- /dev/null +++ b/src/replicate/types/training_get_response.py @@ -0,0 +1,74 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, Optional +from datetime import datetime +from typing_extensions import Literal + +from .._models import BaseModel + +__all__ = ["TrainingGetResponse", "Metrics", "Output", "URLs"] + + +class Metrics(BaseModel): + predict_time: Optional[float] = None + """The amount of CPU or GPU time, in seconds, that the training used while running""" + + +class Output(BaseModel): + version: Optional[str] = None + """The version of the model created by the training""" + + weights: Optional[str] = None + """The weights of the trained model""" + + +class URLs(BaseModel): + cancel: Optional[str] = None + """URL to cancel the training""" + + get: Optional[str] = None + """URL to get the training details""" + + +class TrainingGetResponse(BaseModel): + id: Optional[str] = None + """The unique ID of the training""" + + completed_at: Optional[datetime] = None + """The time when the training completed""" + + created_at: Optional[datetime] = None + """The time when the training was created""" + + error: Optional[str] = None + """Error message if the training failed""" + + input: Optional[Dict[str, object]] = None + """The input parameters used for the training""" + + logs: Optional[str] = None + """The logs from the training process""" + + metrics: Optional[Metrics] = None + """Metrics about the training process""" + + model: Optional[str] = None + """The name of the model in the format owner/name""" + + output: Optional[Output] = None + """The output of the training process""" + + source: Optional[Literal["web", "api"]] = None + """How the training was created""" + + started_at: Optional[datetime] = None + """The time when the training started""" + + status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None + """The current status of the training""" + + urls: Optional[URLs] = None + """URLs for interacting with the training""" + + version: Optional[str] = None + """The ID of the model version used for training""" diff --git a/src/replicate/types/training_list_response.py b/src/replicate/types/training_list_response.py new file mode 100644 index 0000000..02adf3a --- /dev/null +++ b/src/replicate/types/training_list_response.py @@ -0,0 +1,74 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, Optional +from datetime import datetime +from typing_extensions import Literal + +from .._models import BaseModel + +__all__ = ["TrainingListResponse", "Metrics", "Output", "URLs"] + + +class Metrics(BaseModel): + predict_time: Optional[float] = None + """The amount of CPU or GPU time, in seconds, that the training used while running""" + + +class Output(BaseModel): + version: Optional[str] = None + """The version of the model created by the training""" + + weights: Optional[str] = None + """The weights of the trained model""" + + +class URLs(BaseModel): + cancel: Optional[str] = None + """URL to cancel the training""" + + get: Optional[str] = None + """URL to get the training details""" + + +class TrainingListResponse(BaseModel): + id: Optional[str] = None + """The unique ID of the training""" + + completed_at: Optional[datetime] = None + """The time when the training completed""" + + created_at: Optional[datetime] = None + """The time when the training was created""" + + error: Optional[str] = None + """Error message if the training failed""" + + input: Optional[Dict[str, object]] = None + """The input parameters used for the training""" + + logs: Optional[str] = None + """The logs from the training process""" + + metrics: Optional[Metrics] = None + """Metrics about the training process""" + + model: Optional[str] = None + """The name of the model in the format owner/name""" + + output: Optional[Output] = None + """The output of the training process""" + + source: Optional[Literal["web", "api"]] = None + """How the training was created""" + + started_at: Optional[datetime] = None + """The time when the training started""" + + status: Optional[Literal["starting", "processing", "succeeded", "failed", "canceled"]] = None + """The current status of the training""" + + urls: Optional[URLs] = None + """URLs for interacting with the training""" + + version: Optional[str] = None + """The ID of the model version used for training""" diff --git a/tests/api_resources/models/test_versions.py b/tests/api_resources/models/test_versions.py index 295205e..d1fb7a8 100644 --- a/tests/api_resources/models/test_versions.py +++ b/tests/api_resources/models/test_versions.py @@ -8,6 +8,8 @@ import pytest from replicate import ReplicateClient, AsyncReplicateClient +from tests.utils import assert_matches_type +from replicate.types.models import VersionCreateTrainingResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -15,70 +17,6 @@ class TestVersions: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - @pytest.mark.skip() - @parametrize - def test_method_retrieve(self, client: ReplicateClient) -> None: - version = client.models.versions.retrieve( - version_id="version_id", - model_owner="model_owner", - model_name="model_name", - ) - assert version is None - - @pytest.mark.skip() - @parametrize - def test_raw_response_retrieve(self, client: ReplicateClient) -> None: - response = client.models.versions.with_raw_response.retrieve( - version_id="version_id", - model_owner="model_owner", - model_name="model_name", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - version = response.parse() - assert version is None - - @pytest.mark.skip() - @parametrize - def test_streaming_response_retrieve(self, client: ReplicateClient) -> None: - with client.models.versions.with_streaming_response.retrieve( - version_id="version_id", - model_owner="model_owner", - model_name="model_name", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - version = response.parse() - assert version is None - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip() - @parametrize - def test_path_params_retrieve(self, client: ReplicateClient) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): - client.models.versions.with_raw_response.retrieve( - version_id="version_id", - model_owner="", - model_name="model_name", - ) - - with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"): - client.models.versions.with_raw_response.retrieve( - version_id="version_id", - model_owner="model_owner", - model_name="", - ) - - with pytest.raises(ValueError, match=r"Expected a non-empty value for `version_id` but received ''"): - client.models.versions.with_raw_response.retrieve( - version_id="", - model_owner="model_owner", - model_name="model_name", - ) - @pytest.mark.skip() @parametrize def test_method_list(self, client: ReplicateClient) -> None: @@ -205,7 +143,7 @@ def test_method_create_training(self, client: ReplicateClient) -> None: destination="destination", input={}, ) - assert version is None + assert_matches_type(VersionCreateTrainingResponse, version, path=["response"]) @pytest.mark.skip() @parametrize @@ -219,7 +157,7 @@ def test_method_create_training_with_all_params(self, client: ReplicateClient) - webhook="webhook", webhook_events_filter=["start"], ) - assert version is None + assert_matches_type(VersionCreateTrainingResponse, version, path=["response"]) @pytest.mark.skip() @parametrize @@ -235,7 +173,7 @@ def test_raw_response_create_training(self, client: ReplicateClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" version = response.parse() - assert version is None + assert_matches_type(VersionCreateTrainingResponse, version, path=["response"]) @pytest.mark.skip() @parametrize @@ -251,7 +189,7 @@ def test_streaming_response_create_training(self, client: ReplicateClient) -> No assert response.http_request.headers.get("X-Stainless-Lang") == "python" version = response.parse() - assert version is None + assert_matches_type(VersionCreateTrainingResponse, version, path=["response"]) assert cast(Any, response.is_closed) is True @@ -285,14 +223,10 @@ def test_path_params_create_training(self, client: ReplicateClient) -> None: input={}, ) - -class TestAsyncVersions: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) - @pytest.mark.skip() @parametrize - async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None: - version = await async_client.models.versions.retrieve( + def test_method_get(self, client: ReplicateClient) -> None: + version = client.models.versions.get( version_id="version_id", model_owner="model_owner", model_name="model_name", @@ -301,8 +235,8 @@ async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None @pytest.mark.skip() @parametrize - async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - response = await async_client.models.versions.with_raw_response.retrieve( + def test_raw_response_get(self, client: ReplicateClient) -> None: + response = client.models.versions.with_raw_response.get( version_id="version_id", model_owner="model_owner", model_name="model_name", @@ -310,13 +244,13 @@ async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) - assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - version = await response.parse() + version = response.parse() assert version is None @pytest.mark.skip() @parametrize - async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - async with async_client.models.versions.with_streaming_response.retrieve( + def test_streaming_response_get(self, client: ReplicateClient) -> None: + with client.models.versions.with_streaming_response.get( version_id="version_id", model_owner="model_owner", model_name="model_name", @@ -324,35 +258,39 @@ async def test_streaming_response_retrieve(self, async_client: AsyncReplicateCli assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" - version = await response.parse() + version = response.parse() assert version is None assert cast(Any, response.is_closed) is True @pytest.mark.skip() @parametrize - async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None: + def test_path_params_get(self, client: ReplicateClient) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): - await async_client.models.versions.with_raw_response.retrieve( + client.models.versions.with_raw_response.get( version_id="version_id", model_owner="", model_name="model_name", ) with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"): - await async_client.models.versions.with_raw_response.retrieve( + client.models.versions.with_raw_response.get( version_id="version_id", model_owner="model_owner", model_name="", ) with pytest.raises(ValueError, match=r"Expected a non-empty value for `version_id` but received ''"): - await async_client.models.versions.with_raw_response.retrieve( + client.models.versions.with_raw_response.get( version_id="", model_owner="model_owner", model_name="model_name", ) + +class TestAsyncVersions: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + @pytest.mark.skip() @parametrize async def test_method_list(self, async_client: AsyncReplicateClient) -> None: @@ -479,7 +417,7 @@ async def test_method_create_training(self, async_client: AsyncReplicateClient) destination="destination", input={}, ) - assert version is None + assert_matches_type(VersionCreateTrainingResponse, version, path=["response"]) @pytest.mark.skip() @parametrize @@ -493,7 +431,7 @@ async def test_method_create_training_with_all_params(self, async_client: AsyncR webhook="webhook", webhook_events_filter=["start"], ) - assert version is None + assert_matches_type(VersionCreateTrainingResponse, version, path=["response"]) @pytest.mark.skip() @parametrize @@ -509,7 +447,7 @@ async def test_raw_response_create_training(self, async_client: AsyncReplicateCl assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" version = await response.parse() - assert version is None + assert_matches_type(VersionCreateTrainingResponse, version, path=["response"]) @pytest.mark.skip() @parametrize @@ -525,7 +463,7 @@ async def test_streaming_response_create_training(self, async_client: AsyncRepli assert response.http_request.headers.get("X-Stainless-Lang") == "python" version = await response.parse() - assert version is None + assert_matches_type(VersionCreateTrainingResponse, version, path=["response"]) assert cast(Any, response.is_closed) is True @@ -558,3 +496,67 @@ async def test_path_params_create_training(self, async_client: AsyncReplicateCli destination="destination", input={}, ) + + @pytest.mark.skip() + @parametrize + async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + version = await async_client.models.versions.get( + version_id="version_id", + model_owner="model_owner", + model_name="model_name", + ) + assert version is None + + @pytest.mark.skip() + @parametrize + async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + response = await async_client.models.versions.with_raw_response.get( + version_id="version_id", + model_owner="model_owner", + model_name="model_name", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + version = await response.parse() + assert version is None + + @pytest.mark.skip() + @parametrize + async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async with async_client.models.versions.with_streaming_response.get( + version_id="version_id", + model_owner="model_owner", + model_name="model_name", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + version = await response.parse() + assert version is None + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): + await async_client.models.versions.with_raw_response.get( + version_id="version_id", + model_owner="", + model_name="model_name", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"): + await async_client.models.versions.with_raw_response.get( + version_id="version_id", + model_owner="model_owner", + model_name="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `version_id` but received ''"): + await async_client.models.versions.with_raw_response.get( + version_id="", + model_owner="model_owner", + model_name="model_name", + ) diff --git a/tests/api_resources/test_deployments.py b/tests/api_resources/test_deployments.py index b834241..3a01dcb 100644 --- a/tests/api_resources/test_deployments.py +++ b/tests/api_resources/test_deployments.py @@ -10,10 +10,10 @@ from replicate import ReplicateClient, AsyncReplicateClient from tests.utils import assert_matches_type from replicate.types import ( + DeploymentGetResponse, DeploymentListResponse, DeploymentCreateResponse, DeploymentUpdateResponse, - DeploymentRetrieveResponse, ) from replicate.pagination import SyncCursorURLPage, AsyncCursorURLPage @@ -72,58 +72,6 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() - @parametrize - def test_method_retrieve(self, client: ReplicateClient) -> None: - deployment = client.deployments.retrieve( - deployment_name="deployment_name", - deployment_owner="deployment_owner", - ) - assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"]) - - @pytest.mark.skip() - @parametrize - def test_raw_response_retrieve(self, client: ReplicateClient) -> None: - response = client.deployments.with_raw_response.retrieve( - deployment_name="deployment_name", - deployment_owner="deployment_owner", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - deployment = response.parse() - assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"]) - - @pytest.mark.skip() - @parametrize - def test_streaming_response_retrieve(self, client: ReplicateClient) -> None: - with client.deployments.with_streaming_response.retrieve( - deployment_name="deployment_name", - deployment_owner="deployment_owner", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - deployment = response.parse() - assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip() - @parametrize - def test_path_params_retrieve(self, client: ReplicateClient) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"): - client.deployments.with_raw_response.retrieve( - deployment_name="deployment_name", - deployment_owner="", - ) - - with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_name` but received ''"): - client.deployments.with_raw_response.retrieve( - deployment_name="", - deployment_owner="deployment_owner", - ) - @pytest.mark.skip() @parametrize def test_method_update(self, client: ReplicateClient) -> None: @@ -269,6 +217,58 @@ def test_path_params_delete(self, client: ReplicateClient) -> None: deployment_owner="deployment_owner", ) + @pytest.mark.skip() + @parametrize + def test_method_get(self, client: ReplicateClient) -> None: + deployment = client.deployments.get( + deployment_name="deployment_name", + deployment_owner="deployment_owner", + ) + assert_matches_type(DeploymentGetResponse, deployment, path=["response"]) + + @pytest.mark.skip() + @parametrize + def test_raw_response_get(self, client: ReplicateClient) -> None: + response = client.deployments.with_raw_response.get( + deployment_name="deployment_name", + deployment_owner="deployment_owner", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + deployment = response.parse() + assert_matches_type(DeploymentGetResponse, deployment, path=["response"]) + + @pytest.mark.skip() + @parametrize + def test_streaming_response_get(self, client: ReplicateClient) -> None: + with client.deployments.with_streaming_response.get( + deployment_name="deployment_name", + deployment_owner="deployment_owner", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + deployment = response.parse() + assert_matches_type(DeploymentGetResponse, deployment, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + def test_path_params_get(self, client: ReplicateClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"): + client.deployments.with_raw_response.get( + deployment_name="deployment_name", + deployment_owner="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_name` but received ''"): + client.deployments.with_raw_response.get( + deployment_name="", + deployment_owner="deployment_owner", + ) + @pytest.mark.skip() @parametrize def test_method_list_em_all(self, client: ReplicateClient) -> None: @@ -350,58 +350,6 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien assert cast(Any, response.is_closed) is True - @pytest.mark.skip() - @parametrize - async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None: - deployment = await async_client.deployments.retrieve( - deployment_name="deployment_name", - deployment_owner="deployment_owner", - ) - assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"]) - - @pytest.mark.skip() - @parametrize - async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - response = await async_client.deployments.with_raw_response.retrieve( - deployment_name="deployment_name", - deployment_owner="deployment_owner", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - deployment = await response.parse() - assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"]) - - @pytest.mark.skip() - @parametrize - async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - async with async_client.deployments.with_streaming_response.retrieve( - deployment_name="deployment_name", - deployment_owner="deployment_owner", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - deployment = await response.parse() - assert_matches_type(DeploymentRetrieveResponse, deployment, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip() - @parametrize - async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"): - await async_client.deployments.with_raw_response.retrieve( - deployment_name="deployment_name", - deployment_owner="", - ) - - with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_name` but received ''"): - await async_client.deployments.with_raw_response.retrieve( - deployment_name="", - deployment_owner="deployment_owner", - ) - @pytest.mark.skip() @parametrize async def test_method_update(self, async_client: AsyncReplicateClient) -> None: @@ -547,6 +495,58 @@ async def test_path_params_delete(self, async_client: AsyncReplicateClient) -> N deployment_owner="deployment_owner", ) + @pytest.mark.skip() + @parametrize + async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + deployment = await async_client.deployments.get( + deployment_name="deployment_name", + deployment_owner="deployment_owner", + ) + assert_matches_type(DeploymentGetResponse, deployment, path=["response"]) + + @pytest.mark.skip() + @parametrize + async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + response = await async_client.deployments.with_raw_response.get( + deployment_name="deployment_name", + deployment_owner="deployment_owner", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + deployment = await response.parse() + assert_matches_type(DeploymentGetResponse, deployment, path=["response"]) + + @pytest.mark.skip() + @parametrize + async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async with async_client.deployments.with_streaming_response.get( + deployment_name="deployment_name", + deployment_owner="deployment_owner", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + deployment = await response.parse() + assert_matches_type(DeploymentGetResponse, deployment, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_owner` but received ''"): + await async_client.deployments.with_raw_response.get( + deployment_name="deployment_name", + deployment_owner="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `deployment_name` but received ''"): + await async_client.deployments.with_raw_response.get( + deployment_name="", + deployment_owner="deployment_owner", + ) + @pytest.mark.skip() @parametrize async def test_method_list_em_all(self, async_client: AsyncReplicateClient) -> None: diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py index f56f00c..a2a9f52 100644 --- a/tests/api_resources/test_models.py +++ b/tests/api_resources/test_models.py @@ -77,58 +77,6 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() - @parametrize - def test_method_retrieve(self, client: ReplicateClient) -> None: - model = client.models.retrieve( - model_name="model_name", - model_owner="model_owner", - ) - assert model is None - - @pytest.mark.skip() - @parametrize - def test_raw_response_retrieve(self, client: ReplicateClient) -> None: - response = client.models.with_raw_response.retrieve( - model_name="model_name", - model_owner="model_owner", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - model = response.parse() - assert model is None - - @pytest.mark.skip() - @parametrize - def test_streaming_response_retrieve(self, client: ReplicateClient) -> None: - with client.models.with_streaming_response.retrieve( - model_name="model_name", - model_owner="model_owner", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - model = response.parse() - assert model is None - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip() - @parametrize - def test_path_params_retrieve(self, client: ReplicateClient) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): - client.models.with_raw_response.retrieve( - model_name="model_name", - model_owner="", - ) - - with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"): - client.models.with_raw_response.retrieve( - model_name="", - model_owner="model_owner", - ) - @pytest.mark.skip() @parametrize def test_method_list(self, client: ReplicateClient) -> None: @@ -280,6 +228,58 @@ def test_path_params_create_prediction(self, client: ReplicateClient) -> None: input={}, ) + @pytest.mark.skip() + @parametrize + def test_method_get(self, client: ReplicateClient) -> None: + model = client.models.get( + model_name="model_name", + model_owner="model_owner", + ) + assert model is None + + @pytest.mark.skip() + @parametrize + def test_raw_response_get(self, client: ReplicateClient) -> None: + response = client.models.with_raw_response.get( + model_name="model_name", + model_owner="model_owner", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + model = response.parse() + assert model is None + + @pytest.mark.skip() + @parametrize + def test_streaming_response_get(self, client: ReplicateClient) -> None: + with client.models.with_streaming_response.get( + model_name="model_name", + model_owner="model_owner", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + model = response.parse() + assert model is None + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + def test_path_params_get(self, client: ReplicateClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): + client.models.with_raw_response.get( + model_name="model_name", + model_owner="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"): + client.models.with_raw_response.get( + model_name="", + model_owner="model_owner", + ) + class TestAsyncModels: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @@ -343,58 +343,6 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien assert cast(Any, response.is_closed) is True - @pytest.mark.skip() - @parametrize - async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None: - model = await async_client.models.retrieve( - model_name="model_name", - model_owner="model_owner", - ) - assert model is None - - @pytest.mark.skip() - @parametrize - async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - response = await async_client.models.with_raw_response.retrieve( - model_name="model_name", - model_owner="model_owner", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - model = await response.parse() - assert model is None - - @pytest.mark.skip() - @parametrize - async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - async with async_client.models.with_streaming_response.retrieve( - model_name="model_name", - model_owner="model_owner", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - model = await response.parse() - assert model is None - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip() - @parametrize - async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): - await async_client.models.with_raw_response.retrieve( - model_name="model_name", - model_owner="", - ) - - with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"): - await async_client.models.with_raw_response.retrieve( - model_name="", - model_owner="model_owner", - ) - @pytest.mark.skip() @parametrize async def test_method_list(self, async_client: AsyncReplicateClient) -> None: @@ -545,3 +493,55 @@ async def test_path_params_create_prediction(self, async_client: AsyncReplicateC model_owner="model_owner", input={}, ) + + @pytest.mark.skip() + @parametrize + async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + model = await async_client.models.get( + model_name="model_name", + model_owner="model_owner", + ) + assert model is None + + @pytest.mark.skip() + @parametrize + async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + response = await async_client.models.with_raw_response.get( + model_name="model_name", + model_owner="model_owner", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + model = await response.parse() + assert model is None + + @pytest.mark.skip() + @parametrize + async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async with async_client.models.with_streaming_response.get( + model_name="model_name", + model_owner="model_owner", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + model = await response.parse() + assert model is None + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_owner` but received ''"): + await async_client.models.with_raw_response.get( + model_name="model_name", + model_owner="", + ) + + with pytest.raises(ValueError, match=r"Expected a non-empty value for `model_name` but received ''"): + await async_client.models.with_raw_response.get( + model_name="", + model_owner="model_owner", + ) diff --git a/tests/api_resources/test_predictions.py b/tests/api_resources/test_predictions.py index 8fc4142..51bd80d 100644 --- a/tests/api_resources/test_predictions.py +++ b/tests/api_resources/test_predictions.py @@ -69,48 +69,6 @@ def test_streaming_response_create(self, client: ReplicateClient) -> None: assert cast(Any, response.is_closed) is True - @pytest.mark.skip() - @parametrize - def test_method_retrieve(self, client: ReplicateClient) -> None: - prediction = client.predictions.retrieve( - "prediction_id", - ) - assert_matches_type(Prediction, prediction, path=["response"]) - - @pytest.mark.skip() - @parametrize - def test_raw_response_retrieve(self, client: ReplicateClient) -> None: - response = client.predictions.with_raw_response.retrieve( - "prediction_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - prediction = response.parse() - assert_matches_type(Prediction, prediction, path=["response"]) - - @pytest.mark.skip() - @parametrize - def test_streaming_response_retrieve(self, client: ReplicateClient) -> None: - with client.predictions.with_streaming_response.retrieve( - "prediction_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - prediction = response.parse() - assert_matches_type(Prediction, prediction, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip() - @parametrize - def test_path_params_retrieve(self, client: ReplicateClient) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"): - client.predictions.with_raw_response.retrieve( - "", - ) - @pytest.mark.skip() @parametrize def test_method_list(self, client: ReplicateClient) -> None: @@ -190,6 +148,48 @@ def test_path_params_cancel(self, client: ReplicateClient) -> None: "", ) + @pytest.mark.skip() + @parametrize + def test_method_get(self, client: ReplicateClient) -> None: + prediction = client.predictions.get( + "prediction_id", + ) + assert_matches_type(Prediction, prediction, path=["response"]) + + @pytest.mark.skip() + @parametrize + def test_raw_response_get(self, client: ReplicateClient) -> None: + response = client.predictions.with_raw_response.get( + "prediction_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + prediction = response.parse() + assert_matches_type(Prediction, prediction, path=["response"]) + + @pytest.mark.skip() + @parametrize + def test_streaming_response_get(self, client: ReplicateClient) -> None: + with client.predictions.with_streaming_response.get( + "prediction_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + prediction = response.parse() + assert_matches_type(Prediction, prediction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + def test_path_params_get(self, client: ReplicateClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"): + client.predictions.with_raw_response.get( + "", + ) + class TestAsyncPredictions: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @@ -244,48 +244,6 @@ async def test_streaming_response_create(self, async_client: AsyncReplicateClien assert cast(Any, response.is_closed) is True - @pytest.mark.skip() - @parametrize - async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None: - prediction = await async_client.predictions.retrieve( - "prediction_id", - ) - assert_matches_type(Prediction, prediction, path=["response"]) - - @pytest.mark.skip() - @parametrize - async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - response = await async_client.predictions.with_raw_response.retrieve( - "prediction_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - prediction = await response.parse() - assert_matches_type(Prediction, prediction, path=["response"]) - - @pytest.mark.skip() - @parametrize - async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - async with async_client.predictions.with_streaming_response.retrieve( - "prediction_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - prediction = await response.parse() - assert_matches_type(Prediction, prediction, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip() - @parametrize - async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"): - await async_client.predictions.with_raw_response.retrieve( - "", - ) - @pytest.mark.skip() @parametrize async def test_method_list(self, async_client: AsyncReplicateClient) -> None: @@ -364,3 +322,45 @@ async def test_path_params_cancel(self, async_client: AsyncReplicateClient) -> N await async_client.predictions.with_raw_response.cancel( "", ) + + @pytest.mark.skip() + @parametrize + async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + prediction = await async_client.predictions.get( + "prediction_id", + ) + assert_matches_type(Prediction, prediction, path=["response"]) + + @pytest.mark.skip() + @parametrize + async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + response = await async_client.predictions.with_raw_response.get( + "prediction_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + prediction = await response.parse() + assert_matches_type(Prediction, prediction, path=["response"]) + + @pytest.mark.skip() + @parametrize + async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async with async_client.predictions.with_streaming_response.get( + "prediction_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + prediction = await response.parse() + assert_matches_type(Prediction, prediction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `prediction_id` but received ''"): + await async_client.predictions.with_raw_response.get( + "", + ) diff --git a/tests/api_resources/test_trainings.py b/tests/api_resources/test_trainings.py index f59d874..f2dadb1 100644 --- a/tests/api_resources/test_trainings.py +++ b/tests/api_resources/test_trainings.py @@ -8,6 +8,9 @@ import pytest from replicate import ReplicateClient, AsyncReplicateClient +from tests.utils import assert_matches_type +from replicate.types import TrainingGetResponse, TrainingListResponse, TrainingCancelResponse +from replicate.pagination import SyncCursorURLPage, AsyncCursorURLPage base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -15,53 +18,11 @@ class TestTrainings: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - @pytest.mark.skip() - @parametrize - def test_method_retrieve(self, client: ReplicateClient) -> None: - training = client.trainings.retrieve( - "training_id", - ) - assert training is None - - @pytest.mark.skip() - @parametrize - def test_raw_response_retrieve(self, client: ReplicateClient) -> None: - response = client.trainings.with_raw_response.retrieve( - "training_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - training = response.parse() - assert training is None - - @pytest.mark.skip() - @parametrize - def test_streaming_response_retrieve(self, client: ReplicateClient) -> None: - with client.trainings.with_streaming_response.retrieve( - "training_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - training = response.parse() - assert training is None - - assert cast(Any, response.is_closed) is True - - @pytest.mark.skip() - @parametrize - def test_path_params_retrieve(self, client: ReplicateClient) -> None: - with pytest.raises(ValueError, match=r"Expected a non-empty value for `training_id` but received ''"): - client.trainings.with_raw_response.retrieve( - "", - ) - @pytest.mark.skip() @parametrize def test_method_list(self, client: ReplicateClient) -> None: training = client.trainings.list() - assert training is None + assert_matches_type(SyncCursorURLPage[TrainingListResponse], training, path=["response"]) @pytest.mark.skip() @parametrize @@ -71,7 +32,7 @@ def test_raw_response_list(self, client: ReplicateClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = response.parse() - assert training is None + assert_matches_type(SyncCursorURLPage[TrainingListResponse], training, path=["response"]) @pytest.mark.skip() @parametrize @@ -81,7 +42,7 @@ def test_streaming_response_list(self, client: ReplicateClient) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = response.parse() - assert training is None + assert_matches_type(SyncCursorURLPage[TrainingListResponse], training, path=["response"]) assert cast(Any, response.is_closed) is True @@ -91,7 +52,7 @@ def test_method_cancel(self, client: ReplicateClient) -> None: training = client.trainings.cancel( "training_id", ) - assert training is None + assert_matches_type(TrainingCancelResponse, training, path=["response"]) @pytest.mark.skip() @parametrize @@ -103,7 +64,7 @@ def test_raw_response_cancel(self, client: ReplicateClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = response.parse() - assert training is None + assert_matches_type(TrainingCancelResponse, training, path=["response"]) @pytest.mark.skip() @parametrize @@ -115,7 +76,7 @@ def test_streaming_response_cancel(self, client: ReplicateClient) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = response.parse() - assert training is None + assert_matches_type(TrainingCancelResponse, training, path=["response"]) assert cast(Any, response.is_closed) is True @@ -127,57 +88,57 @@ def test_path_params_cancel(self, client: ReplicateClient) -> None: "", ) - -class TestAsyncTrainings: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) - @pytest.mark.skip() @parametrize - async def test_method_retrieve(self, async_client: AsyncReplicateClient) -> None: - training = await async_client.trainings.retrieve( + def test_method_get(self, client: ReplicateClient) -> None: + training = client.trainings.get( "training_id", ) - assert training is None + assert_matches_type(TrainingGetResponse, training, path=["response"]) @pytest.mark.skip() @parametrize - async def test_raw_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - response = await async_client.trainings.with_raw_response.retrieve( + def test_raw_response_get(self, client: ReplicateClient) -> None: + response = client.trainings.with_raw_response.get( "training_id", ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - training = await response.parse() - assert training is None + training = response.parse() + assert_matches_type(TrainingGetResponse, training, path=["response"]) @pytest.mark.skip() @parametrize - async def test_streaming_response_retrieve(self, async_client: AsyncReplicateClient) -> None: - async with async_client.trainings.with_streaming_response.retrieve( + def test_streaming_response_get(self, client: ReplicateClient) -> None: + with client.trainings.with_streaming_response.get( "training_id", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" - training = await response.parse() - assert training is None + training = response.parse() + assert_matches_type(TrainingGetResponse, training, path=["response"]) assert cast(Any, response.is_closed) is True @pytest.mark.skip() @parametrize - async def test_path_params_retrieve(self, async_client: AsyncReplicateClient) -> None: + def test_path_params_get(self, client: ReplicateClient) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `training_id` but received ''"): - await async_client.trainings.with_raw_response.retrieve( + client.trainings.with_raw_response.get( "", ) + +class TestAsyncTrainings: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + @pytest.mark.skip() @parametrize async def test_method_list(self, async_client: AsyncReplicateClient) -> None: training = await async_client.trainings.list() - assert training is None + assert_matches_type(AsyncCursorURLPage[TrainingListResponse], training, path=["response"]) @pytest.mark.skip() @parametrize @@ -187,7 +148,7 @@ async def test_raw_response_list(self, async_client: AsyncReplicateClient) -> No assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = await response.parse() - assert training is None + assert_matches_type(AsyncCursorURLPage[TrainingListResponse], training, path=["response"]) @pytest.mark.skip() @parametrize @@ -197,7 +158,7 @@ async def test_streaming_response_list(self, async_client: AsyncReplicateClient) assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = await response.parse() - assert training is None + assert_matches_type(AsyncCursorURLPage[TrainingListResponse], training, path=["response"]) assert cast(Any, response.is_closed) is True @@ -207,7 +168,7 @@ async def test_method_cancel(self, async_client: AsyncReplicateClient) -> None: training = await async_client.trainings.cancel( "training_id", ) - assert training is None + assert_matches_type(TrainingCancelResponse, training, path=["response"]) @pytest.mark.skip() @parametrize @@ -219,7 +180,7 @@ async def test_raw_response_cancel(self, async_client: AsyncReplicateClient) -> assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = await response.parse() - assert training is None + assert_matches_type(TrainingCancelResponse, training, path=["response"]) @pytest.mark.skip() @parametrize @@ -231,7 +192,7 @@ async def test_streaming_response_cancel(self, async_client: AsyncReplicateClien assert response.http_request.headers.get("X-Stainless-Lang") == "python" training = await response.parse() - assert training is None + assert_matches_type(TrainingCancelResponse, training, path=["response"]) assert cast(Any, response.is_closed) is True @@ -242,3 +203,45 @@ async def test_path_params_cancel(self, async_client: AsyncReplicateClient) -> N await async_client.trainings.with_raw_response.cancel( "", ) + + @pytest.mark.skip() + @parametrize + async def test_method_get(self, async_client: AsyncReplicateClient) -> None: + training = await async_client.trainings.get( + "training_id", + ) + assert_matches_type(TrainingGetResponse, training, path=["response"]) + + @pytest.mark.skip() + @parametrize + async def test_raw_response_get(self, async_client: AsyncReplicateClient) -> None: + response = await async_client.trainings.with_raw_response.get( + "training_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + training = await response.parse() + assert_matches_type(TrainingGetResponse, training, path=["response"]) + + @pytest.mark.skip() + @parametrize + async def test_streaming_response_get(self, async_client: AsyncReplicateClient) -> None: + async with async_client.trainings.with_streaming_response.get( + "training_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + training = await response.parse() + assert_matches_type(TrainingGetResponse, training, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + async def test_path_params_get(self, async_client: AsyncReplicateClient) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `training_id` but received ''"): + await async_client.trainings.with_raw_response.get( + "", + ) diff --git a/tests/conftest.py b/tests/conftest.py index 97007f1..7d1538b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,7 +10,7 @@ from replicate import ReplicateClient, AsyncReplicateClient if TYPE_CHECKING: - from _pytest.fixtures import FixtureRequest + from _pytest.fixtures import FixtureRequest # pyright: ignore[reportPrivateImportUsage] pytest.register_assert_rewrite("tests.utils") diff --git a/tests/test_client.py b/tests/test_client.py index bd5b999..279675f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -346,7 +346,7 @@ def test_validate_headers(self) -> None: assert request.headers.get("Authorization") == f"Bearer {bearer_token}" with pytest.raises(ReplicateClientError): - with update_env(**{"REPLICATE_CLIENT_BEARER_TOKEN": Omit()}): + with update_env(**{"REPLICATE_API_TOKEN": Omit()}): client2 = ReplicateClient(base_url=base_url, bearer_token=None, _strict_response_validation=True) _ = client2 @@ -1126,7 +1126,7 @@ def test_validate_headers(self) -> None: assert request.headers.get("Authorization") == f"Bearer {bearer_token}" with pytest.raises(ReplicateClientError): - with update_env(**{"REPLICATE_CLIENT_BEARER_TOKEN": Omit()}): + with update_env(**{"REPLICATE_API_TOKEN": Omit()}): client2 = AsyncReplicateClient(base_url=base_url, bearer_token=None, _strict_response_validation=True) _ = client2 diff --git a/tests/test_models.py b/tests/test_models.py index ee374fa..cf8173c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -492,12 +492,15 @@ class Model(BaseModel): resource_id: Optional[str] = None m = Model.construct() + assert m.resource_id is None assert "resource_id" not in m.model_fields_set m = Model.construct(resource_id=None) + assert m.resource_id is None assert "resource_id" in m.model_fields_set m = Model.construct(resource_id="foo") + assert m.resource_id == "foo" assert "resource_id" in m.model_fields_set @@ -832,7 +835,7 @@ class B(BaseModel): @pytest.mark.skipif(not PYDANTIC_V2, reason="TypeAliasType is not supported in Pydantic v1") def test_type_alias_type() -> None: - Alias = TypeAliasType("Alias", str) + Alias = TypeAliasType("Alias", str) # pyright: ignore class Model(BaseModel): alias: Alias