Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .release-please-manifest.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
".": "2.0.0-alpha.21"
".": "2.0.0-alpha.22"
}
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Changelog

## 2.0.0-alpha.22 (2025-08-28)

Full Changelog: [v2.0.0-alpha.21...v2.0.0-alpha.22](https://github.com/replicate/replicate-python-stainless/compare/v2.0.0-alpha.21...v2.0.0-alpha.22)

### Bug Fixes

* implement lazy client creation in replicate.use() ([#57](https://github.com/replicate/replicate-python-stainless/issues/57)) ([caf4c4e](https://github.com/replicate/replicate-python-stainless/commit/caf4c4efa2be271144b22b93a38ea490b10ad86b))

## 2.0.0-alpha.21 (2025-08-26)

Full Changelog: [v2.0.0-alpha.20...v2.0.0-alpha.21](https://github.com/replicate/replicate-python-stainless/compare/v2.0.0-alpha.20...v2.0.0-alpha.21)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "replicate"
version = "2.0.0-alpha.21"
version = "2.0.0-alpha.22"
description = "The official Python library for the replicate API"
dynamic = ["readme"]
license = "Apache-2.0"
Expand Down
10 changes: 7 additions & 3 deletions src/replicate/_module_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,17 @@ def _run(*args, **kwargs):
return _load_client().run(*args, **kwargs)

def _use(ref, *, hint=None, streaming=False, use_async=False, **kwargs):
from .lib._predictions_use import use

if use_async:
# For async, we need to use AsyncReplicate instead
from ._client import AsyncReplicate

client = AsyncReplicate()
return client.use(ref, hint=hint, streaming=streaming, **kwargs)
return _load_client().use(ref, hint=hint, streaming=streaming, **kwargs)
return use(AsyncReplicate, ref, hint=hint, streaming=streaming, **kwargs)

from ._client import Replicate

return use(Replicate, ref, hint=hint, streaming=streaming, **kwargs)

run = _run
use = _use
Expand Down
2 changes: 1 addition & 1 deletion src/replicate/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

__title__ = "replicate"
__version__ = "2.0.0-alpha.21" # x-release-please-version
__version__ = "2.0.0-alpha.22" # x-release-please-version
31 changes: 19 additions & 12 deletions src/replicate/lib/_predictions_use.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Any,
Dict,
List,
Type,
Tuple,
Union,
Generic,
Expand Down Expand Up @@ -436,15 +437,18 @@ class Function(Generic[Input, Output]):
A wrapper for a Replicate model that can be called as a function.
"""

_client: Client
_ref: str
_streaming: bool

def __init__(self, client: Client, ref: str, *, streaming: bool) -> None:
self._client = client
def __init__(self, client: Type[Client], ref: str, *, streaming: bool) -> None:
self._client_class = client
self._ref = ref
self._streaming = streaming

@property
def _client(self) -> Client:
return self._client_class()

def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output:
return self.create(*args, **inputs).output()

Expand Down Expand Up @@ -666,16 +670,19 @@ class AsyncFunction(Generic[Input, Output]):
An async wrapper for a Replicate model that can be called as a function.
"""

_client: AsyncClient
_ref: str
_streaming: bool
_openapi_schema: Optional[Dict[str, Any]] = None

def __init__(self, client: AsyncClient, ref: str, *, streaming: bool) -> None:
self._client = client
def __init__(self, client: Type[AsyncClient], ref: str, *, streaming: bool) -> None:
self._client_class = client
self._ref = ref
self._streaming = streaming

@property
def _client(self) -> AsyncClient:
return self._client_class()

@cached_property
def _parsed_ref(self) -> Tuple[str, str, Optional[str]]:
return ModelVersionIdentifier.parse(self._ref)
Expand Down Expand Up @@ -804,7 +811,7 @@ async def openapi_schema(self) -> Dict[str, Any]:

@overload
def use(
client: Client,
client: Type[Client],
ref: Union[str, FunctionRef[Input, Output]],
*,
hint: Optional[Callable[Input, Output]] = None,
Expand All @@ -814,7 +821,7 @@ def use(

@overload
def use(
client: Client,
client: Type[Client],
ref: Union[str, FunctionRef[Input, Output]],
*,
hint: Optional[Callable[Input, Output]] = None,
Expand All @@ -824,7 +831,7 @@ def use(

@overload
def use(
client: AsyncClient,
client: Type[AsyncClient],
ref: Union[str, FunctionRef[Input, Output]],
*,
hint: Optional[Callable[Input, Output]] = None,
Expand All @@ -834,7 +841,7 @@ def use(

@overload
def use(
client: AsyncClient,
client: Type[AsyncClient],
ref: Union[str, FunctionRef[Input, Output]],
*,
hint: Optional[Callable[Input, Output]] = None,
Expand All @@ -843,7 +850,7 @@ def use(


def use(
client: Union[Client, AsyncClient],
client: Union[Type[Client], Type[AsyncClient]],
ref: Union[str, FunctionRef[Input, Output]],
*,
hint: Optional[Callable[Input, Output]] = None, # pylint: disable=unused-argument # noqa: ARG001 # required for type inference
Expand All @@ -868,7 +875,7 @@ def use(
except AttributeError:
pass

if isinstance(client, AsyncClient):
if issubclass(client, AsyncClient):
# TODO: Fix type inference for AsyncFunction return type
return AsyncFunction(client, str(ref), streaming=streaming) # type: ignore[return-value]

Expand Down
52 changes: 52 additions & 0 deletions tests/test_simple_lazy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Test lazy client creation in replicate.use()."""

import os
import sys
from unittest.mock import MagicMock, patch


def test_use_does_not_raise_without_token():
    """Check that ``replicate.use()`` succeeds when no API token is available.

    The test runs with an emptied ``os.environ`` and with ``cog`` mapped to
    ``None`` in ``sys.modules`` (which makes ``import cog`` fail), then
    asserts that ``replicate.use()`` still returns an object.
    """
    # Put the in-repo "src" directory first on the import path for the
    # duration of this test only. Patching the attribute (rather than calling
    # ``sys.path.insert``) restores the original search path on exit, so the
    # entry does not leak into, or accumulate across, other tests.
    src_first = ["src", *sys.path]

    with patch.object(sys, "path", src_first):
        with patch.dict(os.environ, {}, clear=True), patch.dict(sys.modules, {"cog": None}):
            import replicate

            # Should not raise an exception
            model = replicate.use("test/model")  # type: ignore[misc]
            assert model is not None


def test_cog_current_scope():
    """Check that ``cog.current_scope()`` is called once per client creation.

    A fake ``cog`` module is installed in ``sys.modules`` whose
    ``current_scope().context.items()`` yields a ``REPLICATE_API_TOKEN``
    entry. Accessing ``_client`` on a model returned by ``replicate.use()``
    is expected to call ``cog.current_scope`` once; doing the same on a
    second model is expected to call it a second time.
    """
    # Fake scope context that exposes the API token through ``items()``.
    mock_context = MagicMock()
    mock_context.items.return_value = [("REPLICATE_API_TOKEN", "test-token-1")]

    mock_scope = MagicMock()
    mock_scope.context = mock_context

    mock_cog = MagicMock()
    mock_cog.current_scope.return_value = mock_scope

    # Put the in-repo "src" directory first on the import path for the
    # duration of this test only. Patching the attribute (rather than calling
    # ``sys.path.insert``) restores the original search path on exit, so the
    # entry does not leak into, or accumulate across, other tests.
    src_first = ["src", *sys.path]

    with patch.object(sys, "path", src_first):
        with patch.dict(os.environ, {}, clear=True), patch.dict(sys.modules, {"cog": mock_cog}):
            import replicate

            model = replicate.use("test/model")  # type: ignore[misc]

            # Access the client property - this should trigger client creation
            # and, with it, one cog.current_scope() call.
            _ = model._client

            assert mock_cog.current_scope.call_count == 1

            # Change the token the fake context reports.
            mock_context.items.return_value = [("REPLICATE_API_TOKEN", "test-token-2")]

            # Create a new model and access its client - this should trigger
            # another client creation and therefore a second call.
            model2 = replicate.use("test/model2")  # type: ignore[misc]
            _ = model2._client

            assert mock_cog.current_scope.call_count == 2