diff --git a/.release-please-manifest.json b/.release-please-manifest.json index e429c67..01c33a5 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "2.0.0-alpha.25" + ".": "2.0.0-alpha.26" } \ No newline at end of file diff --git a/.stats.yml b/.stats.yml index a536750..0277638 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ -configured_endpoints: 35 +configured_endpoints: 36 openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/replicate%2Freplicate-client-87c7c57bd75c54990c679c9e87d009851cdff572815a55d1b6ee4d4ee20adaa1.yml openapi_spec_hash: d987f14befa536004eece7b49caad993 -config_hash: b1b4f5d24ba07b4667ffe7b9dec081e3 +config_hash: a916e7f3559ab312c7b6696cd6b35fb5 diff --git a/CHANGELOG.md b/CHANGELOG.md index a7b7d12..2bbfc35 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,23 @@ # Changelog +## 2.0.0-alpha.26 (2025-09-17) + +Full Changelog: [v2.0.0-alpha.25...v2.0.0-alpha.26](https://github.com/replicate/replicate-python-stainless/compare/v2.0.0-alpha.25...v2.0.0-alpha.26) + +### Features + +* **api:** add new replicate.search() method (beta) ([30d7019](https://github.com/replicate/replicate-python-stainless/commit/30d701999ea48ee65c5e5fd467072ccd5db35f87)) + + +### Bug Fixes + +* **tests:** fix tests for module-level client ([1e72f23](https://github.com/replicate/replicate-python-stainless/commit/1e72f23da3f0930955fe126848a8a8e58dbb710e)) + + +### Chores + +* **internal:** update pydantic dependency ([54872cb](https://github.com/replicate/replicate-python-stainless/commit/54872cb5e00fb65cae2abffcf0169a8b138e35fa)) + ## 2.0.0-alpha.25 (2025-09-15) Full Changelog: [v2.0.0-alpha.24...v2.0.0-alpha.25](https://github.com/replicate/replicate-python-stainless/compare/v2.0.0-alpha.24...v2.0.0-alpha.25) diff --git a/api.md b/api.md index 9bf36c7..d037762 100644 --- a/api.md +++ b/api.md @@ -1,3 +1,15 @@ +# Replicate + +Types: + +```python +from replicate.types import SearchResponse +``` + +Methods: + +- replicate.search(\*\*params) -> SearchResponse + # Collections Types: diff --git a/pyproject.toml b/pyproject.toml index c0ff437..044be7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "replicate" -version = "2.0.0-alpha.25" +version = "2.0.0-alpha.26" description = "The official Python library for the replicate API" dynamic = ["readme"] license = "Apache-2.0" diff --git a/requirements-dev.lock b/requirements-dev.lock index 22fdd88..839ba5d 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -88,9 +88,9 @@ pluggy==1.5.0 propcache==0.3.1 # via aiohttp # via yarl -pydantic==2.10.3 +pydantic==2.11.9 # via replicate -pydantic-core==2.27.1 +pydantic-core==2.33.2 # via pydantic pygments==2.18.0 # via rich @@ -126,6 +126,9 @@ typing-extensions==4.12.2 # via pydantic-core # via pyright # via replicate + # via typing-inspection +typing-inspection==0.4.1 + # via pydantic virtualenv==20.24.5 # via nox yarl==1.20.0 diff --git a/requirements.lock b/requirements.lock index d76f084..9c126e0 100644 --- a/requirements.lock +++ b/requirements.lock @@ -55,9 +55,9 @@ multidict==6.4.4 propcache==0.3.1 # via aiohttp # via yarl -pydantic==2.10.3 +pydantic==2.11.9 # via replicate -pydantic-core==2.27.1 +pydantic-core==2.33.2 # via pydantic sniffio==1.3.0 # via anyio @@ -68,5 +68,8 @@ typing-extensions==4.12.2 # via pydantic # via pydantic-core # via replicate + # via typing-inspection +typing-inspection==0.4.1 + # via pydantic yarl==1.20.0 # via aiohttp diff --git a/src/replicate/_client.py b/src/replicate/_client.py index 7113e54..390a552 100644 --- a/src/replicate/_client.py +++ b/src/replicate/_client.py @@ -27,25 +27,42 @@ from . import _exceptions from ._qs import Querystring +from .types import client_search_params from ._types import ( NOT_GIVEN, + Body, Omit, + Query, + Headers, Timeout, NotGiven, Transport, ProxiesTypes, RequestOptions, ) -from ._utils import is_given, get_async_library +from ._utils import ( + is_given, + maybe_transform, + get_async_library, + async_maybe_transform, +) from ._compat import cached_property from ._version import __version__ +from ._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) from ._streaming import Stream as Stream, AsyncStream as AsyncStream from ._exceptions import APIStatusError, ReplicateError from ._base_client import ( DEFAULT_MAX_RETRIES, SyncAPIClient, AsyncAPIClient, + make_request_options, ) +from .types.search_response import SearchResponse if TYPE_CHECKING: from .resources import files, models, account, hardware, webhooks, trainings, collections, deployments, predictions @@ -354,6 +371,70 @@ def copy( # client.with_options(timeout=10).foo.create(...) with_options = copy + def search( + self, + *, + query: str, + limit: int | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> SearchResponse: + """ + Search for public models, collections, and docs using a text query. + + For models, the response includes all model data, plus a new `metadata` object + with the following fields: + + - `generated_description`: A longer and more detailed AI-generated description + of the model + - `tags`: An array of tags for the model + - `score`: A score for the model's relevance to the search query + + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + "https://api.replicate.com/v1/search?query=nano+banana" + ``` + + Note: This search API is currently in beta and may change in future versions. + + Args: + query: The search query string + + limit: Maximum number of model results to return (1-50, defaults to 20) + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self.get( + "/search", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + { + "query": query, + "limit": limit, + }, + client_search_params.ClientSearchParams, + ), + ), + cast_to=SearchResponse, + ) + @override def _make_status_error( self, @@ -665,6 +746,70 @@ def copy( # client.with_options(timeout=10).foo.create(...) with_options = copy + async def search( + self, + *, + query: str, + limit: int | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> SearchResponse: + """ + Search for public models, collections, and docs using a text query. + + For models, the response includes all model data, plus a new `metadata` object + with the following fields: + + - `generated_description`: A longer and more detailed AI-generated description + of the model + - `tags`: An array of tags for the model + - `score`: A score for the model's relevance to the search query + + Example cURL request: + + ```console + curl -s \\ + -H "Authorization: Bearer $REPLICATE_API_TOKEN" \\ + "https://api.replicate.com/v1/search?query=nano+banana" + ``` + + Note: This search API is currently in beta and may change in future versions. + + Args: + query: The search query string + + limit: Maximum number of model results to return (1-50, defaults to 20) + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return await self.get( + "/search", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform( + { + "query": query, + "limit": limit, + }, + client_search_params.ClientSearchParams, + ), + ), + cast_to=SearchResponse, + ) + @override def _make_status_error( self, @@ -705,6 +850,10 @@ class ReplicateWithRawResponse: def __init__(self, client: Replicate) -> None: self._client = client + self.search = to_raw_response_wrapper( + client.search, + ) + @cached_property def collections(self) -> collections.CollectionsResourceWithRawResponse: from .resources.collections import CollectionsResourceWithRawResponse @@ -766,6 +915,10 @@ class AsyncReplicateWithRawResponse: def __init__(self, client: AsyncReplicate) -> None: self._client = client + self.search = async_to_raw_response_wrapper( + client.search, + ) + @cached_property def collections(self) -> collections.AsyncCollectionsResourceWithRawResponse: from .resources.collections import AsyncCollectionsResourceWithRawResponse @@ -827,6 +980,10 @@ class ReplicateWithStreamedResponse: def __init__(self, client: Replicate) -> None: self._client = client + self.search = to_streamed_response_wrapper( + client.search, + ) + @cached_property def collections(self) -> collections.CollectionsResourceWithStreamingResponse: from .resources.collections import CollectionsResourceWithStreamingResponse @@ -888,6 +1045,10 @@ class AsyncReplicateWithStreamedResponse: def __init__(self, client: AsyncReplicate) -> None: self._client = client + self.search = async_to_streamed_response_wrapper( + client.search, + ) + @cached_property def collections(self) -> collections.AsyncCollectionsResourceWithStreamingResponse: from .resources.collections import AsyncCollectionsResourceWithStreamingResponse diff --git a/src/replicate/_models.py b/src/replicate/_models.py index 3a6017e..6a3cd1d 100644 --- a/src/replicate/_models.py +++ b/src/replicate/_models.py @@ -256,7 +256,7 @@ def model_dump( mode: Literal["json", "python"] | str = "python", include: IncEx | None = None, exclude: IncEx | None = None, - by_alias: bool = False, + by_alias: bool | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, @@ -264,6 +264,7 @@ def model_dump( warnings: bool | Literal["none", "warn", "error"] = True, context: dict[str, Any] | None = None, serialize_as_any: bool = False, + fallback: Callable[[Any], Any] | None = None, ) -> dict[str, Any]: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump @@ -295,10 +296,12 @@ def model_dump( raise ValueError("context is only supported in Pydantic v2") if serialize_as_any != False: raise ValueError("serialize_as_any is only supported in Pydantic v2") + if fallback is not None: + raise ValueError("fallback is only supported in Pydantic v2") dumped = super().dict( # pyright: ignore[reportDeprecated] include=include, exclude=exclude, - by_alias=by_alias, + by_alias=by_alias if by_alias is not None else False, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, @@ -313,13 +316,14 @@ def model_dump_json( indent: int | None = None, include: IncEx | None = None, exclude: IncEx | None = None, - by_alias: bool = False, + by_alias: bool | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, warnings: bool | Literal["none", "warn", "error"] = True, context: dict[str, Any] | None = None, + fallback: Callable[[Any], Any] | None = None, serialize_as_any: bool = False, ) -> str: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json @@ -348,11 +352,13 @@ def model_dump_json( raise ValueError("context is only supported in Pydantic v2") if serialize_as_any != False: raise ValueError("serialize_as_any is only supported in Pydantic v2") + if fallback is not None: + raise ValueError("fallback is only supported in Pydantic v2") return super().json( # type: ignore[reportDeprecated] indent=indent, include=include, exclude=exclude, - by_alias=by_alias, + by_alias=by_alias if by_alias is not None else False, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, diff --git a/src/replicate/_version.py b/src/replicate/_version.py index 8060a57..52084c8 100644 --- a/src/replicate/_version.py +++ b/src/replicate/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "replicate" -__version__ = "2.0.0-alpha.25" # x-release-please-version +__version__ = "2.0.0-alpha.26" # x-release-please-version diff --git a/src/replicate/types/__init__.py b/src/replicate/types/__init__.py index ea8b27c..3f375e8 100644 --- a/src/replicate/types/__init__.py +++ b/src/replicate/types/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations from .prediction import Prediction as Prediction +from .search_response import SearchResponse as SearchResponse from .file_get_response import FileGetResponse as FileGetResponse from .file_create_params import FileCreateParams as FileCreateParams from .file_list_response import FileListResponse as FileListResponse @@ -11,6 +12,7 @@ from .model_list_response import ModelListResponse as ModelListResponse from .model_search_params import ModelSearchParams as ModelSearchParams from .account_get_response import AccountGetResponse as AccountGetResponse +from .client_search_params import ClientSearchParams as ClientSearchParams from .file_create_response import FileCreateResponse as FileCreateResponse from .file_download_params import FileDownloadParams as FileDownloadParams from .model_create_response import ModelCreateResponse as ModelCreateResponse diff --git a/src/replicate/types/client_search_params.py b/src/replicate/types/client_search_params.py new file mode 100644 index 0000000..da7f176 --- /dev/null +++ b/src/replicate/types/client_search_params.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, TypedDict + +__all__ = ["ClientSearchParams"] + + +class ClientSearchParams(TypedDict, total=False): + query: Required[str] + """The search query string""" + + limit: int + """Maximum number of model results to return (1-50, defaults to 20)""" diff --git a/src/replicate/types/search_response.py b/src/replicate/types/search_response.py new file mode 100644 index 0000000..ac78880 --- /dev/null +++ b/src/replicate/types/search_response.py @@ -0,0 +1,109 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import List, Optional +from typing_extensions import Literal + +from .._models import BaseModel + +__all__ = ["SearchResponse", "Collection", "Model", "ModelMetadata", "ModelModel", "Page"] + + +class Collection(BaseModel): + description: str + """A description of the collection""" + + name: str + """The name of the collection""" + + slug: str + """The slug of the collection (lowercase with dashes)""" + + models: Optional[List[str]] = None + """Array of model names in the collection""" + + +class ModelMetadata(BaseModel): + generated_description: Optional[str] = None + """AI-generated detailed description of the model""" + + score: Optional[float] = None + """Search relevance score""" + + tags: Optional[List[str]] = None + """Array of descriptive tags for the model""" + + +class ModelModel(BaseModel): + cover_image_url: Optional[str] = None + """A URL for the model's cover image""" + + default_example: Optional[object] = None + """The model's default example prediction""" + + description: Optional[str] = None + """A description of the model""" + + github_url: Optional[str] = None + """A URL for the model's source code on GitHub""" + + is_official: Optional[bool] = None + """Boolean indicating whether the model is officially maintained by Replicate. + + Official models are always on, have stable API interfaces, and predictable + pricing. + """ + + latest_version: Optional[object] = None + """The model's latest version""" + + license_url: Optional[str] = None + """A URL for the model's license""" + + name: Optional[str] = None + """The name of the model""" + + owner: Optional[str] = None + """The name of the user or organization that owns the model""" + + paper_url: Optional[str] = None + """A URL for the model's paper""" + + run_count: Optional[int] = None + """The number of times the model has been run""" + + url: Optional[str] = None + """The URL of the model on Replicate""" + + visibility: Optional[Literal["public", "private"]] = None + """Whether the model is public or private""" + + +class Model(BaseModel): + metadata: ModelMetadata + + model: ModelModel + + +class Page(BaseModel): + href: str + """URL path to the page""" + + name: str + """Title of the page""" + + +class SearchResponse(BaseModel): + collections: List[Collection] + """Array of collections that match the search query""" + + models: List[Model] + """ + Array of models that match the search query, each containing model data and + extra metadata + """ + + pages: List[Page] + """Array of Replicate documentation pages that match the search query""" + + query: str + """The search term that was evaluated""" diff --git a/tests/api_resources/test_client.py b/tests/api_resources/test_client.py new file mode 100644 index 0000000..a278be6 --- /dev/null +++ b/tests/api_resources/test_client.py @@ -0,0 +1,110 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from replicate import Replicate, AsyncReplicate +from tests.utils import assert_matches_type +from replicate.types import SearchResponse + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestClient: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @pytest.mark.skip(reason="Prism tests are disabled") + @parametrize + def test_method_search(self, client: Replicate) -> None: + client_ = client.search( + query="nano banana", + ) + assert_matches_type(SearchResponse, client_, path=["response"]) + + @pytest.mark.skip(reason="Prism tests are disabled") + @parametrize + def test_method_search_with_all_params(self, client: Replicate) -> None: + client_ = client.search( + query="nano banana", + limit=10, + ) + assert_matches_type(SearchResponse, client_, path=["response"]) + + @pytest.mark.skip(reason="Prism tests are disabled") + @parametrize + def test_raw_response_search(self, client: Replicate) -> None: + response = client.with_raw_response.search( + query="nano banana", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + client_ = response.parse() + assert_matches_type(SearchResponse, client_, path=["response"]) + + @pytest.mark.skip(reason="Prism tests are disabled") + @parametrize + def test_streaming_response_search(self, client: Replicate) -> None: + with client.with_streaming_response.search( + query="nano banana", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + client_ = response.parse() + assert_matches_type(SearchResponse, client_, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncClient: + parametrize = pytest.mark.parametrize( + "async_client", [False, True, {"http_client": "aiohttp"}], indirect=True, ids=["loose", "strict", "aiohttp"] + ) + + @pytest.mark.skip(reason="Prism tests are disabled") + @parametrize + async def test_method_search(self, async_client: AsyncReplicate) -> None: + client = await async_client.search( + query="nano banana", + ) + assert_matches_type(SearchResponse, client, path=["response"]) + + @pytest.mark.skip(reason="Prism tests are disabled") + @parametrize + async def test_method_search_with_all_params(self, async_client: AsyncReplicate) -> None: + client = await async_client.search( + query="nano banana", + limit=10, + ) + assert_matches_type(SearchResponse, client, path=["response"]) + + @pytest.mark.skip(reason="Prism tests are disabled") + @parametrize + async def test_raw_response_search(self, async_client: AsyncReplicate) -> None: + response = await async_client.with_raw_response.search( + query="nano banana", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + client = await response.parse() + assert_matches_type(SearchResponse, client, path=["response"]) + + @pytest.mark.skip(reason="Prism tests are disabled") + @parametrize + async def test_streaming_response_search(self, async_client: AsyncReplicate) -> None: + async with async_client.with_streaming_response.search( + query="nano banana", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + client = await response.parse() + assert_matches_type(SearchResponse, client, path=["response"]) + + assert cast(Any, response.is_closed) is True