diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 3b286e5..81f6dc2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,7 +12,6 @@ jobs: lint: name: lint runs-on: ubuntu-latest - steps: - uses: actions/checkout@v4 @@ -33,7 +32,6 @@ jobs: test: name: test runs-on: ubuntu-latest - steps: - uses: actions/checkout@v4 diff --git a/CHANGELOG.md b/CHANGELOG.md index a37ea99..3e1bb8b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ Full Changelog: [v0.0.1-alpha.0...v0.1.0-alpha.1](https://github.com/replicate/r * **api:** manual updates ([#22](https://github.com/replicate/replicate-python-stainless/issues/22)) ([573bbb3](https://github.com/replicate/replicate-python-stainless/commit/573bbb3fa46346d1703d076daa85541b947c27a2)) * **api:** manual updates ([#23](https://github.com/replicate/replicate-python-stainless/issues/23)) ([7962ea7](https://github.com/replicate/replicate-python-stainless/commit/7962ea7f5c5e3c8f253f66b8f13eadc9927fae7d)) * **api:** manual updates ([#24](https://github.com/replicate/replicate-python-stainless/issues/24)) ([d31ada3](https://github.com/replicate/replicate-python-stainless/commit/d31ada3ab8d19413c1ca55535f1788ae9b5d443d)) +* **api:** re-enable module client ([4f83487](https://github.com/replicate/replicate-python-stainless/commit/4f8348757c683395344c8dea1adb75af941f67e1)) * **api:** update pagination configs ([#25](https://github.com/replicate/replicate-python-stainless/issues/25)) ([8a2cc9f](https://github.com/replicate/replicate-python-stainless/commit/8a2cc9f87cf6edb18ad906bbb0b82372b4b82099)) * **api:** update via SDK Studio ([3bf3415](https://github.com/replicate/replicate-python-stainless/commit/3bf3415ce21bb9fc55f80809239e58d64f34fb61)) * **api:** update via SDK Studio ([aafbabf](https://github.com/replicate/replicate-python-stainless/commit/aafbabfdbac1c43a547f277769c82585c616a3b4)) @@ -27,16 +28,27 @@ Full Changelog: [v0.0.1-alpha.0...v0.1.0-alpha.1](https://github.com/replicate/r * **ci:** ensure pip is always available ([#14](https://github.com/replicate/replicate-python-stainless/issues/14)) ([d4f8f18](https://github.com/replicate/replicate-python-stainless/commit/d4f8f18d0d369dedc9e09551e87483f8f1787fd7)) * **ci:** remove publishing patch ([#15](https://github.com/replicate/replicate-python-stainless/issues/15)) ([002b758](https://github.com/replicate/replicate-python-stainless/commit/002b7581debae5b5bc97ec24cc0cea129100dfde)) +* **perf:** optimize some hot paths ([d14ecac](https://github.com/replicate/replicate-python-stainless/commit/d14ecaca89424e6aac85827ae0efd5e6da7d8c03)) +* **perf:** skip traversing types for NotGiven values ([35ce48c](https://github.com/replicate/replicate-python-stainless/commit/35ce48c5e3e93605a70f90270487e02e121561b5)) +* pluralize `list` response variables ([#26](https://github.com/replicate/replicate-python-stainless/issues/26)) ([19f7bd9](https://github.com/replicate/replicate-python-stainless/commit/19f7bd9fe7d8422ae7abe6d5070e29f4a572a1d7)) * **types:** handle more discriminated union shapes ([#13](https://github.com/replicate/replicate-python-stainless/issues/13)) ([4ca1ca8](https://github.com/replicate/replicate-python-stainless/commit/4ca1ca8606fef7bc40144c1ae6246784ed754687)) ### Chores +* **client:** minor internal fixes ([21d9c7b](https://github.com/replicate/replicate-python-stainless/commit/21d9c7be7c1c5cb93466b6f6d16804a7d38701b2)) * fix typos ([#18](https://github.com/replicate/replicate-python-stainless/issues/18)) ([54d4e6d](https://github.com/replicate/replicate-python-stainless/commit/54d4e6da6a757ad6e5c899b79018ca39eed2a124)) * go live ([#1](https://github.com/replicate/replicate-python-stainless/issues/1)) ([bd9a84a](https://github.com/replicate/replicate-python-stainless/commit/bd9a84ae91a2c72fa49182f9ff8ea0b78e7cc343)) * **internal:** bump rye to 0.44.0 ([#12](https://github.com/replicate/replicate-python-stainless/issues/12)) ([c9d2593](https://github.com/replicate/replicate-python-stainless/commit/c9d2593f4e9d324c12e2ca323563a317d0ea1751)) * **internal:** codegen related update ([#11](https://github.com/replicate/replicate-python-stainless/issues/11)) ([41c787d](https://github.com/replicate/replicate-python-stainless/commit/41c787d2c196d59c48ec3056d93dd48bf530de72)) +* **internal:** expand CI branch coverage ([183a5dc](https://github.com/replicate/replicate-python-stainless/commit/183a5dc798f30ff70e7e96ecb3b5c8ade704af6a)) +* **internal:** fix module client tests ([fc9ac15](https://github.com/replicate/replicate-python-stainless/commit/fc9ac15ad5580b72719da024d95e067316f04ef2)) +* **internal:** reduce CI branch coverage ([2c61820](https://github.com/replicate/replicate-python-stainless/commit/2c6182009bbee43f3995be0844983c18e4b147fb)) * **internal:** remove extra empty newlines ([#10](https://github.com/replicate/replicate-python-stainless/issues/10)) ([1c63514](https://github.com/replicate/replicate-python-stainless/commit/1c635145a8a70f3cb0ea432387f72eafa64ca5f6)) +* **internal:** remove trailing character ([#27](https://github.com/replicate/replicate-python-stainless/issues/27)) ([0a6e3f2](https://github.com/replicate/replicate-python-stainless/commit/0a6e3f29ddeecd0267c57536cf0fa07218c18d03)) +* **internal:** slight transform perf improvement ([#28](https://github.com/replicate/replicate-python-stainless/issues/28)) ([da30360](https://github.com/replicate/replicate-python-stainless/commit/da303609abb6a7d069b4e35036f68bae7b009cb2)) +* **internal:** update pyright settings ([65494e5](https://github.com/replicate/replicate-python-stainless/commit/65494e5b623b3d2cf248b6da88c7f3e50160dc94)) * **internal:** updates ([b7424d7](https://github.com/replicate/replicate-python-stainless/commit/b7424d7cc1c9fa440b58fa9e51331f3c77fbd83d)) * remove custom code ([31aa7ed](https://github.com/replicate/replicate-python-stainless/commit/31aa7edc04d5c9f408c06e4051c0ba343b9761ac)) +* sync repo ([9a71c71](https://github.com/replicate/replicate-python-stainless/commit/9a71c71fd94f1e216398c616dc6c24e0eff10c89)) * update SDK settings ([#3](https://github.com/replicate/replicate-python-stainless/issues/3)) ([27b5f18](https://github.com/replicate/replicate-python-stainless/commit/27b5f1897b823349b29dbc82b1f6742d5d704c9e)) diff --git a/README.md b/README.md index e5edec0..0015b4c 100644 --- a/README.md +++ b/README.md @@ -33,8 +33,8 @@ client = ReplicateClient( ), # This is the default and can be omitted ) -account = client.accounts.list() -print(account.type) +accounts = client.accounts.list() +print(accounts.type) ``` While you can provide a `bearer_token` keyword argument, @@ -59,8 +59,8 @@ client = AsyncReplicateClient( async def main() -> None: - account = await client.accounts.list() - print(account.type) + accounts = await client.accounts.list() + print(accounts.type) asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 0bf058e..1e9a8a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,6 +147,7 @@ exclude = [ ] reportImplicitOverride = true +reportOverlappingOverload = false reportImportCycles = false reportPrivateUsage = false diff --git a/src/replicate/__init__.py b/src/replicate/__init__.py index b4c9b81..fa787bd 100644 --- a/src/replicate/__init__.py +++ b/src/replicate/__init__.py @@ -1,5 +1,9 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. +from __future__ import annotations + +from typing_extensions import override + from . import types from ._types import NOT_GIVEN, Omit, NoneType, NotGiven, Transport, ProxiesTypes from ._utils import file_from_path @@ -92,3 +96,146 @@ except (TypeError, AttributeError): # Some of our exported symbols are builtins which we can't set attributes for. pass + +# ------ Module level client ------ +import typing as _t + +import httpx as _httpx + +from ._base_client import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES + +bearer_token: str | None = None + +base_url: str | _httpx.URL | None = None + +timeout: float | Timeout | None = DEFAULT_TIMEOUT + +max_retries: int = DEFAULT_MAX_RETRIES + +default_headers: _t.Mapping[str, str] | None = None + +default_query: _t.Mapping[str, object] | None = None + +http_client: _httpx.Client | None = None + + +class _ModuleClient(ReplicateClient): + # Note: we have to use type: ignores here as overriding class members + # with properties is technically unsafe but it is fine for our use case + + @property # type: ignore + @override + def bearer_token(self) -> str | None: + return bearer_token + + @bearer_token.setter # type: ignore + def bearer_token(self, value: str | None) -> None: # type: ignore + global bearer_token + + bearer_token = value + + @property + @override + def base_url(self) -> _httpx.URL: + if base_url is not None: + return _httpx.URL(base_url) + + return super().base_url + + @base_url.setter + def base_url(self, url: _httpx.URL | str) -> None: + super().base_url = url # type: ignore[misc] + + @property # type: ignore + @override + def timeout(self) -> float | Timeout | None: + return timeout + + @timeout.setter # type: ignore + def timeout(self, value: float | Timeout | None) -> None: # type: ignore + global timeout + + timeout = value + + @property # type: ignore + @override + def max_retries(self) -> int: + return max_retries + + @max_retries.setter # type: ignore + def max_retries(self, value: int) -> None: # type: ignore + global max_retries + + max_retries = value + + @property # type: ignore + @override + def _custom_headers(self) -> _t.Mapping[str, str] | None: + return default_headers + + @_custom_headers.setter # type: ignore + def _custom_headers(self, value: _t.Mapping[str, str] | None) -> None: # type: ignore + global default_headers + + default_headers = value + + @property # type: ignore + @override + def _custom_query(self) -> _t.Mapping[str, object] | None: + return default_query + + @_custom_query.setter # type: ignore + def _custom_query(self, value: _t.Mapping[str, object] | None) -> None: # type: ignore + global default_query + + default_query = value + + @property # type: ignore + @override + def _client(self) -> _httpx.Client: + return http_client or super()._client + + @_client.setter # type: ignore + def _client(self, value: _httpx.Client) -> None: # type: ignore + global http_client + + http_client = value + + +_client: ReplicateClient | None = None + + +def _load_client() -> ReplicateClient: # type: ignore[reportUnusedFunction] + global _client + + if _client is None: + _client = _ModuleClient( + bearer_token=bearer_token, + base_url=base_url, + timeout=timeout, + max_retries=max_retries, + default_headers=default_headers, + default_query=default_query, + http_client=http_client, + ) + return _client + + return _client + + +def _reset_client() -> None: # type: ignore[reportUnusedFunction] + global _client + + _client = None + + +from ._module_client import ( + models as models, + accounts as accounts, + hardware as hardware, + webhooks as webhooks, + trainings as trainings, + collections as collections, + deployments as deployments, + predictions as predictions, +) diff --git a/src/replicate/_base_client.py b/src/replicate/_base_client.py index 5cbd218..d55fecf 100644 --- a/src/replicate/_base_client.py +++ b/src/replicate/_base_client.py @@ -409,7 +409,8 @@ def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0 idempotency_header = self._idempotency_header if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers: - headers[idempotency_header] = options.idempotency_key or self._idempotency_key() + options.idempotency_key = options.idempotency_key or self._idempotency_key() + headers[idempotency_header] = options.idempotency_key # Don't set these headers if they were already set or removed by the caller. We check # `custom_headers`, which can contain `Omit()`, instead of `headers` to account for the removal case. @@ -943,6 +944,10 @@ def _request( request = self._build_request(options, retries_taken=retries_taken) self._prepare_request(request) + if options.idempotency_key: + # ensure the idempotency key is reused between requests + input_options.idempotency_key = options.idempotency_key + kwargs: HttpxSendArgs = {} if self.custom_auth is not None: kwargs["auth"] = self.custom_auth @@ -1475,6 +1480,10 @@ async def _request( request = self._build_request(options, retries_taken=retries_taken) await self._prepare_request(request) + if options.idempotency_key: + # ensure the idempotency key is reused between requests + input_options.idempotency_key = options.idempotency_key + kwargs: HttpxSendArgs = {} if self.custom_auth is not None: kwargs["auth"] = self.custom_auth diff --git a/src/replicate/_module_client.py b/src/replicate/_module_client.py new file mode 100644 index 0000000..4e2ca08 --- /dev/null +++ b/src/replicate/_module_client.py @@ -0,0 +1,64 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing_extensions import override + +from . import resources, _load_client +from ._utils import LazyProxy + + +class ModelsResourceProxy(LazyProxy[resources.ModelsResource]): + @override + def __load__(self) -> resources.ModelsResource: + return _load_client().models + + +class HardwareResourceProxy(LazyProxy[resources.HardwareResource]): + @override + def __load__(self) -> resources.HardwareResource: + return _load_client().hardware + + +class AccountsResourceProxy(LazyProxy[resources.AccountsResource]): + @override + def __load__(self) -> resources.AccountsResource: + return _load_client().accounts + + +class WebhooksResourceProxy(LazyProxy[resources.WebhooksResource]): + @override + def __load__(self) -> resources.WebhooksResource: + return _load_client().webhooks + + +class TrainingsResourceProxy(LazyProxy[resources.TrainingsResource]): + @override + def __load__(self) -> resources.TrainingsResource: + return _load_client().trainings + + +class CollectionsResourceProxy(LazyProxy[resources.CollectionsResource]): + @override + def __load__(self) -> resources.CollectionsResource: + return _load_client().collections + + +class DeploymentsResourceProxy(LazyProxy[resources.DeploymentsResource]): + @override + def __load__(self) -> resources.DeploymentsResource: + return _load_client().deployments + + +class PredictionsResourceProxy(LazyProxy[resources.PredictionsResource]): + @override + def __load__(self) -> resources.PredictionsResource: + return _load_client().predictions + + +models: resources.ModelsResource = ModelsResourceProxy().__as_proxied__() +hardware: resources.HardwareResource = HardwareResourceProxy().__as_proxied__() +accounts: resources.AccountsResource = AccountsResourceProxy().__as_proxied__() +webhooks: resources.WebhooksResource = WebhooksResourceProxy().__as_proxied__() +trainings: resources.TrainingsResource = TrainingsResourceProxy().__as_proxied__() +collections: resources.CollectionsResource = CollectionsResourceProxy().__as_proxied__() +deployments: resources.DeploymentsResource = DeploymentsResourceProxy().__as_proxied__() +predictions: resources.PredictionsResource = PredictionsResourceProxy().__as_proxied__() diff --git a/src/replicate/_utils/_transform.py b/src/replicate/_utils/_transform.py index 7ac2e17..b0cc20a 100644 --- a/src/replicate/_utils/_transform.py +++ b/src/replicate/_utils/_transform.py @@ -5,13 +5,15 @@ import pathlib from typing import Any, Mapping, TypeVar, cast from datetime import date, datetime -from typing_extensions import Literal, get_args, override, get_type_hints +from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints import anyio import pydantic from ._utils import ( is_list, + is_given, + lru_cache, is_mapping, is_iterable, ) @@ -108,6 +110,7 @@ class Params(TypedDict, total=False): return cast(_T, transformed) +@lru_cache(maxsize=8096) def _get_annotated_type(type_: type) -> type | None: """If the given type is an `Annotated` type then it is returned, if not `None` is returned. @@ -142,6 +145,10 @@ def _maybe_transform_key(key: str, type_: type) -> str: return key +def _no_transform_needed(annotation: type) -> bool: + return annotation == float or annotation == int + + def _transform_recursive( data: object, *, @@ -184,6 +191,15 @@ def _transform_recursive( return cast(object, data) inner_type = extract_type_arg(stripped_type, 0) + if _no_transform_needed(inner_type): + # for some types there is no need to transform anything, so we can get a small + # perf boost from skipping that work. + # + # but we still need to convert to a list to ensure the data is json-serializable + if is_list(data): + return data + return list(data) + return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] if is_union_type(stripped_type): @@ -245,6 +261,11 @@ def _transform_typeddict( result: dict[str, object] = {} annotations = get_type_hints(expected_type, include_extras=True) for key, value in data.items(): + if not is_given(value): + # we don't need to include `NotGiven` values here as they'll + # be stripped out before the request is sent anyway + continue + type_ = annotations.get(key) if type_ is None: # we do not have a type annotation for this field, leave it as is @@ -332,6 +353,15 @@ async def _async_transform_recursive( return cast(object, data) inner_type = extract_type_arg(stripped_type, 0) + if _no_transform_needed(inner_type): + # for some types there is no need to transform anything, so we can get a small + # perf boost from skipping that work. + # + # but we still need to convert to a list to ensure the data is json-serializable + if is_list(data): + return data + return list(data) + return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] if is_union_type(stripped_type): @@ -393,6 +423,11 @@ async def _async_transform_typeddict( result: dict[str, object] = {} annotations = get_type_hints(expected_type, include_extras=True) for key, value in data.items(): + if not is_given(value): + # we don't need to include `NotGiven` values here as they'll + # be stripped out before the request is sent anyway + continue + type_ = annotations.get(key) if type_ is None: # we do not have a type annotation for this field, leave it as is @@ -400,3 +435,13 @@ async def _async_transform_typeddict( else: result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_) return result + + +@lru_cache(maxsize=8096) +def get_type_hints( + obj: Any, + globalns: dict[str, Any] | None = None, + localns: Mapping[str, Any] | None = None, + include_extras: bool = False, +) -> dict[str, Any]: + return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras) diff --git a/src/replicate/_utils/_typing.py b/src/replicate/_utils/_typing.py index 278749b..1958820 100644 --- a/src/replicate/_utils/_typing.py +++ b/src/replicate/_utils/_typing.py @@ -13,6 +13,7 @@ get_origin, ) +from ._utils import lru_cache from .._types import InheritsGeneric from .._compat import is_union as _is_union @@ -66,6 +67,7 @@ def is_type_alias_type(tp: Any, /) -> TypeIs[typing_extensions.TypeAliasType]: # Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]] +@lru_cache(maxsize=8096) def strip_annotated_type(typ: type) -> type: if is_required_type(typ) or is_annotated_type(typ): return strip_annotated_type(cast(type, get_args(typ)[0])) diff --git a/tests/test_client.py b/tests/test_client.py index 90a05ca..bd5b999 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1639,7 +1639,7 @@ def test_get_platform(self) -> None: import threading from replicate._utils import asyncify - from replicate._base_client import get_platform + from replicate._base_client import get_platform async def test_main() -> None: result = await asyncify(get_platform)() diff --git a/tests/test_module_client.py b/tests/test_module_client.py new file mode 100644 index 0000000..327a9a0 --- /dev/null +++ b/tests/test_module_client.py @@ -0,0 +1,87 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import httpx +import pytest +from httpx import URL + +import replicate +from replicate import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES + + +def reset_state() -> None: + replicate._reset_client() + replicate.bearer_token = None or "My Bearer Token" + replicate.base_url = None + replicate.timeout = DEFAULT_TIMEOUT + replicate.max_retries = DEFAULT_MAX_RETRIES + replicate.default_headers = None + replicate.default_query = None + replicate.http_client = None + + +@pytest.fixture(autouse=True) +def reset_state_fixture() -> None: + reset_state() + + +def test_base_url_option() -> None: + assert replicate.base_url is None + assert replicate.collections._client.base_url == URL("https://api.replicate.com/v1/") + + replicate.base_url = "http://foo.com" + + assert replicate.base_url == URL("http://foo.com") + assert replicate.collections._client.base_url == URL("http://foo.com") + + +def test_timeout_option() -> None: + assert replicate.timeout == replicate.DEFAULT_TIMEOUT + assert replicate.collections._client.timeout == replicate.DEFAULT_TIMEOUT + + replicate.timeout = 3 + + assert replicate.timeout == 3 + assert replicate.collections._client.timeout == 3 + + +def test_max_retries_option() -> None: + assert replicate.max_retries == replicate.DEFAULT_MAX_RETRIES + assert replicate.collections._client.max_retries == replicate.DEFAULT_MAX_RETRIES + + replicate.max_retries = 1 + + assert replicate.max_retries == 1 + assert replicate.collections._client.max_retries == 1 + + +def test_default_headers_option() -> None: + assert replicate.default_headers == None + + replicate.default_headers = {"Foo": "Bar"} + + assert replicate.default_headers["Foo"] == "Bar" + assert replicate.collections._client.default_headers["Foo"] == "Bar" + + +def test_default_query_option() -> None: + assert replicate.default_query is None + assert replicate.collections._client._custom_query == {} + + replicate.default_query = {"Foo": {"nested": 1}} + + assert replicate.default_query["Foo"] == {"nested": 1} + assert replicate.collections._client._custom_query["Foo"] == {"nested": 1} + + +def test_http_client_option() -> None: + assert replicate.http_client is None + + original_http_client = replicate.collections._client._client + assert original_http_client is not None + + new_client = httpx.Client() + replicate.http_client = new_client + + assert replicate.collections._client._client is new_client diff --git a/tests/test_transform.py b/tests/test_transform.py index d75c59c..bb44092 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -8,7 +8,7 @@ import pytest -from replicate._types import Base64FileInput +from replicate._types import NOT_GIVEN, Base64FileInput from replicate._utils import ( PropertyInfo, transform as _transform, @@ -432,3 +432,22 @@ async def test_base64_file_input(use_async: bool) -> None: assert await transform({"foo": io.BytesIO(b"Hello, world!")}, TypedDictBase64Input, use_async) == { "foo": "SGVsbG8sIHdvcmxkIQ==" } # type: ignore[comparison-overlap] + + +@parametrize +@pytest.mark.asyncio +async def test_transform_skipping(use_async: bool) -> None: + # lists of ints are left as-is + data = [1, 2, 3] + assert await transform(data, List[int], use_async) is data + + # iterables of ints are converted to a list + data = iter([1, 2, 3]) + assert await transform(data, Iterable[int], use_async) == [1, 2, 3] + + +@parametrize +@pytest.mark.asyncio +async def test_strips_notgiven(use_async: bool) -> None: + assert await transform({"foo_bar": "bar"}, Foo1, use_async) == {"fooBar": "bar"} + assert await transform({"foo_bar": NOT_GIVEN}, Foo1, use_async) == {}