Skip to content
10 changes: 7 additions & 3 deletions src/replicate/_module_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,17 @@ def _run(*args, **kwargs):
return _load_client().run(*args, **kwargs)

def _use(ref, *, hint=None, streaming=False, use_async=False, **kwargs):
    """Module-level ``use`` entry point.

    Delegates to ``lib._predictions_use.use``, passing it a client *class*
    (``AsyncReplicate`` when ``use_async`` is true, ``Replicate`` otherwise)
    instead of a constructed client, so no client is instantiated by this
    function itself.

    Args:
        ref: Model reference, forwarded unchanged to ``use``.
        hint: Forwarded unchanged to ``use``.
        streaming: Forwarded unchanged to ``use``.
        use_async: When true, select the async client class.
        **kwargs: Extra keyword arguments forwarded to ``use``.

    Returns:
        Whatever ``use`` returns for the selected client class.
    """
    from .lib._predictions_use import use

    if use_async:
        # For async, we need to use AsyncReplicate instead
        from ._client import AsyncReplicate

        return use(AsyncReplicate, ref, hint=hint, streaming=streaming, **kwargs)

    from ._client import Replicate

    return use(Replicate, ref, hint=hint, streaming=streaming, **kwargs)

run = _run
use = _use
Expand Down
31 changes: 19 additions & 12 deletions src/replicate/lib/_predictions_use.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Any,
Dict,
List,
Type,
Tuple,
Union,
Generic,
Expand Down Expand Up @@ -436,15 +437,18 @@ class Function(Generic[Input, Output]):
A wrapper for a Replicate model that can be called as a function.
"""

_client: Client
_ref: str
_streaming: bool

def __init__(self, client: Client, ref: str, *, streaming: bool) -> None:
self._client = client
def __init__(self, client: Type[Client], ref: str, *, streaming: bool) -> None:
self._client_class = client
self._ref = ref
self._streaming = streaming

@property
def _client(self) -> Client:
return self._client_class()

def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output:
return self.create(*args, **inputs).output()

Expand Down Expand Up @@ -666,16 +670,19 @@ class AsyncFunction(Generic[Input, Output]):
An async wrapper for a Replicate model that can be called as a function.
"""

_client: AsyncClient
_ref: str
_streaming: bool
_openapi_schema: Optional[Dict[str, Any]] = None

def __init__(self, client: AsyncClient, ref: str, *, streaming: bool) -> None:
self._client = client
def __init__(self, client: Type[AsyncClient], ref: str, *, streaming: bool) -> None:
self._client_class = client
self._ref = ref
self._streaming = streaming

@property
def _client(self) -> AsyncClient:
return self._client_class()

@cached_property
def _parsed_ref(self) -> Tuple[str, str, Optional[str]]:
return ModelVersionIdentifier.parse(self._ref)
Expand Down Expand Up @@ -804,7 +811,7 @@ async def openapi_schema(self) -> Dict[str, Any]:

@overload
def use(
client: Client,
client: Type[Client],
ref: Union[str, FunctionRef[Input, Output]],
*,
hint: Optional[Callable[Input, Output]] = None,
Expand All @@ -814,7 +821,7 @@ def use(

@overload
def use(
client: Client,
client: Type[Client],
ref: Union[str, FunctionRef[Input, Output]],
*,
hint: Optional[Callable[Input, Output]] = None,
Expand All @@ -824,7 +831,7 @@ def use(

@overload
def use(
client: AsyncClient,
client: Type[AsyncClient],
ref: Union[str, FunctionRef[Input, Output]],
*,
hint: Optional[Callable[Input, Output]] = None,
Expand All @@ -834,7 +841,7 @@ def use(

@overload
def use(
client: AsyncClient,
client: Type[AsyncClient],
ref: Union[str, FunctionRef[Input, Output]],
*,
hint: Optional[Callable[Input, Output]] = None,
Expand All @@ -843,7 +850,7 @@ def use(


def use(
client: Union[Client, AsyncClient],
client: Union[Type[Client], Type[AsyncClient]],
ref: Union[str, FunctionRef[Input, Output]],
*,
hint: Optional[Callable[Input, Output]] = None, # pylint: disable=unused-argument # noqa: ARG001 # required for type inference
Expand All @@ -868,7 +875,7 @@ def use(
except AttributeError:
pass

if isinstance(client, AsyncClient):
if issubclass(client, AsyncClient):
# TODO: Fix type inference for AsyncFunction return type
return AsyncFunction(client, str(ref), streaming=streaming) # type: ignore[return-value]

Expand Down
52 changes: 52 additions & 0 deletions tests/test_simple_lazy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Test lazy client creation in replicate.use()."""

import os
import sys
from unittest.mock import MagicMock, patch


def test_use_does_not_raise_without_token():
    """replicate.use() must succeed with an empty environment and no importable cog."""
    sys.path.insert(0, "src")

    # Both patches are active only inside the ``with`` block: os.environ is
    # emptied, and sys.modules["cog"] is set to None so ``import cog`` fails.
    empty_env = patch.dict(os.environ, {}, clear=True)
    cog_missing = patch.dict(sys.modules, {"cog": None})
    with empty_env, cog_missing:
        import replicate

        # Should not raise an exception
        wrapper = replicate.use("test/model")  # type: ignore[misc]
        assert wrapper is not None


def test_cog_current_scope():
"""Test that cog.current_scope().context is read on each client creation."""
sys.path.insert(0, "src")

mock_context = MagicMock()
mock_context.items.return_value = [("REPLICATE_API_TOKEN", "test-token-1")]

mock_scope = MagicMock()
mock_scope.context = mock_context

mock_cog = MagicMock()
mock_cog.current_scope.return_value = mock_scope

with patch.dict(os.environ, {}, clear=True):
with patch.dict(sys.modules, {"cog": mock_cog}):
import replicate

model = replicate.use("test/model") # type: ignore[misc]

# Access the client property - this should trigger client creation and cog.current_scope call
_ = model._client

assert mock_cog.current_scope.call_count == 1

# Change the token and access client again - should trigger another call
mock_context.items.return_value = [("REPLICATE_API_TOKEN", "test-token-2")]

# Create a new model to trigger another client creation
model2 = replicate.use("test/model2") # type: ignore[misc]
_ = model2._client
Comment on lines +40 to +50
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor tweak: we want to assert that the same model refreshes its client between calls. Since we're already looking at the client, we can check its bearer_token to verify that it has actually picked up the new token.

Suggested change
# Access the client property - this should trigger client creation and cog.current_scope call
_ = model._client
assert mock_cog.current_scope.call_count == 1
# Change the token and access client again - should trigger another call
mock_context.items.return_value = [("REPLICATE_API_TOKEN", "test-token-2")]
# Create a new model to trigger another client creation
model2 = replicate.use("test/model2") # type: ignore[misc]
_ = model2._client
# Access the client property - this should trigger client creation and cog.current_scope call
assert model._client.bearer_token == "test-token-1"
assert mock_cog.current_scope.call_count == 1
# Change the token and access client again - should trigger another call
mock_context.items.return_value = [("REPLICATE_API_TOKEN", "test-token-2")]
# Assert that the second time picks up the new token
assert model._client.bearer_token == "test-token-2"


assert mock_cog.current_scope.call_count == 2
Loading