Skip to content

Commit e05e719

Browse files
committed
fix: implement lazy client creation in replicate.use()
Fixes issue where replicate.use() would fail if no API token was available at call time, even when token becomes available later (e.g., from cog.current_scope). Changes: - Modified Function/AsyncFunction classes to accept client factories - Added _client property that creates client on demand - Updated module client to pass factory functions instead of instances - Token is now retrieved from current scope when model is called This maintains full backward compatibility while enabling use in Cog pipelines where tokens are provided through the execution context.
1 parent c976aff commit e05e719

File tree

3 files changed

+113
-18
lines changed

3 files changed

+113
-18
lines changed

src/replicate/_module_client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,11 @@ def _use(ref, *, hint=None, streaming=False, use_async=False, **kwargs):
9191
if use_async:
9292
# For async, we need to use AsyncReplicate instead
9393
from ._client import AsyncReplicate
94+
from .lib._predictions_use import use
9495

95-
client = AsyncReplicate()
96-
return client.use(ref, hint=hint, streaming=streaming, **kwargs)
97-
return _load_client().use(ref, hint=hint, streaming=streaming, **kwargs)
96+
return use(lambda: AsyncReplicate(), ref, hint=hint, streaming=streaming, **kwargs)
97+
from .lib._predictions_use import use
98+
return use(_load_client, ref, hint=hint, streaming=streaming, **kwargs)
9899

99100
run = _run
100101
use = _use

src/replicate/lib/_predictions_use.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -436,15 +436,20 @@ class Function(Generic[Input, Output]):
436436
A wrapper for a Replicate model that can be called as a function.
437437
"""
438438

439-
_client: Client
440439
_ref: str
441440
_streaming: bool
442441

443-
def __init__(self, client: Client, ref: str, *, streaming: bool) -> None:
444-
self._client = client
442+
def __init__(self, client: Union[Client, Callable[[], Client]], ref: str, *, streaming: bool) -> None:
443+
self._client_or_factory = client
445444
self._ref = ref
446445
self._streaming = streaming
447446

447+
@property
448+
def _client(self) -> Client:
449+
if callable(self._client_or_factory):
450+
return self._client_or_factory()
451+
return self._client_or_factory
452+
448453
def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output:
449454
return self.create(*args, **inputs).output()
450455

@@ -666,16 +671,21 @@ class AsyncFunction(Generic[Input, Output]):
666671
An async wrapper for a Replicate model that can be called as a function.
667672
"""
668673

669-
_client: AsyncClient
670674
_ref: str
671675
_streaming: bool
672676
_openapi_schema: Optional[Dict[str, Any]] = None
673677

674-
def __init__(self, client: AsyncClient, ref: str, *, streaming: bool) -> None:
675-
self._client = client
678+
def __init__(self, client: Union[AsyncClient, Callable[[], AsyncClient]], ref: str, *, streaming: bool) -> None:
679+
self._client_or_factory = client
676680
self._ref = ref
677681
self._streaming = streaming
678682

683+
@property
684+
def _client(self) -> AsyncClient:
685+
if callable(self._client_or_factory):
686+
return self._client_or_factory()
687+
return self._client_or_factory
688+
679689
@cached_property
680690
def _parsed_ref(self) -> Tuple[str, str, Optional[str]]:
681691
return ModelVersionIdentifier.parse(self._ref)
@@ -804,7 +814,7 @@ async def openapi_schema(self) -> Dict[str, Any]:
804814

805815
@overload
806816
def use(
807-
client: Client,
817+
client: Union[Client, Callable[[], Client]],
808818
ref: Union[str, FunctionRef[Input, Output]],
809819
*,
810820
hint: Optional[Callable[Input, Output]] = None,
@@ -814,7 +824,7 @@ def use(
814824

815825
@overload
816826
def use(
817-
client: Client,
827+
client: Union[Client, Callable[[], Client]],
818828
ref: Union[str, FunctionRef[Input, Output]],
819829
*,
820830
hint: Optional[Callable[Input, Output]] = None,
@@ -824,7 +834,7 @@ def use(
824834

825835
@overload
826836
def use(
827-
client: AsyncClient,
837+
client: Union[AsyncClient, Callable[[], AsyncClient]],
828838
ref: Union[str, FunctionRef[Input, Output]],
829839
*,
830840
hint: Optional[Callable[Input, Output]] = None,
@@ -834,7 +844,7 @@ def use(
834844

835845
@overload
836846
def use(
837-
client: AsyncClient,
847+
client: Union[AsyncClient, Callable[[], AsyncClient]],
838848
ref: Union[str, FunctionRef[Input, Output]],
839849
*,
840850
hint: Optional[Callable[Input, Output]] = None,
@@ -843,7 +853,7 @@ def use(
843853

844854

845855
def use(
846-
client: Union[Client, AsyncClient],
856+
client: Union[Client, AsyncClient, Callable[[], Client], Callable[[], AsyncClient]],
847857
ref: Union[str, FunctionRef[Input, Output]],
848858
*,
849859
hint: Optional[Callable[Input, Output]] = None, # pylint: disable=unused-argument # noqa: ARG001 # required for type inference
@@ -868,9 +878,14 @@ def use(
868878
except AttributeError:
869879
pass
870880

871-
if isinstance(client, AsyncClient):
881+
# Determine if this is async by checking the type
882+
is_async = isinstance(client, AsyncClient) or (
883+
callable(client) and isinstance(client(), AsyncClient)
884+
)
885+
886+
if is_async:
872887
# TODO: Fix type inference for AsyncFunction return type
873888
return AsyncFunction(client, str(ref), streaming=streaming) # type: ignore[return-value]
874-
875-
# TODO: Fix type inference for Function return type
876-
return Function(client, str(ref), streaming=streaming) # type: ignore[return-value]
889+
else:
890+
# TODO: Fix type inference for Function return type
891+
return Function(client, str(ref), streaming=streaming) # type: ignore[return-value]

tests/test_simple_lazy.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Simple test showing the lazy client fix works."""
2+
3+
import os
4+
from unittest.mock import MagicMock, patch
5+
import sys
6+
7+
8+
def test_use_does_not_create_client_immediately():
9+
"""Test that replicate.use() does not create a client until the model is called."""
10+
sys.path.insert(0, 'src')
11+
12+
# Clear any existing token to simulate the original error condition
13+
with patch.dict(os.environ, {}, clear=True):
14+
with patch.dict(sys.modules, {"cog": None}):
15+
try:
16+
import replicate
17+
# This should work now - no client is created yet
18+
model = replicate.use("test/model")
19+
20+
# Verify we got a Function object back
21+
from replicate.lib._predictions_use import Function
22+
assert isinstance(model, Function)
23+
print("✓ replicate.use() works without immediate client creation")
24+
25+
# Verify the client is stored as a callable (factory function)
26+
assert callable(model._client)
27+
print("✓ Client is stored as factory function")
28+
29+
except Exception as e:
30+
print(f"✗ Test failed: {e}")
31+
raise
32+
33+
34+
def test_client_created_when_model_called():
35+
"""Test that the client is created when the model is called."""
36+
sys.path.insert(0, 'src')
37+
38+
# Mock the client creation to track when it happens
39+
created_clients = []
40+
41+
def track_client_creation(*args, **kwargs):
42+
client = MagicMock()
43+
client.bearer_token = kwargs.get('bearer_token', 'no-token')
44+
created_clients.append(client)
45+
return client
46+
47+
# Mock cog to provide a token
48+
mock_scope = MagicMock()
49+
mock_scope.context.items.return_value = [("REPLICATE_API_TOKEN", "cog-token")]
50+
mock_cog = MagicMock()
51+
mock_cog.current_scope.return_value = mock_scope
52+
53+
with patch.dict(os.environ, {}, clear=True):
54+
with patch.dict(sys.modules, {"cog": mock_cog}):
55+
with patch('replicate._module_client._ModuleClient', side_effect=track_client_creation):
56+
import replicate
57+
58+
# Create model function - should not create client yet
59+
model = replicate.use("test/model")
60+
assert len(created_clients) == 0
61+
print("✓ No client created when use() is called")
62+
63+
# Try to call the model - this should create a client
64+
try:
65+
model(prompt="test")
66+
except Exception:
67+
# Expected to fail due to mocking, but client should be created
68+
pass
69+
70+
# Verify client was created with the cog token
71+
assert len(created_clients) == 1
72+
assert created_clients[0].bearer_token == "cog-token"
73+
print("✓ Client created with correct token when model is called")
74+
75+
76+
if __name__ == "__main__":
77+
test_use_does_not_create_client_immediately()
78+
test_client_created_when_model_called()
79+
print("\n✓ All tests passed! The lazy client fix works correctly.")

0 commit comments

Comments
 (0)