Skip to content

Commit caf4c4e

Browse files
zekearon
andauthored
fix: implement lazy client creation in replicate.use() (#57)
* fix: implement lazy client creation in replicate.use() Fixes issue where replicate.use() would fail if no API token was available at call time, even when token becomes available later (e.g., from cog.current_scope). Changes: - Modified Function/AsyncFunction classes to accept client factories - Added _client property that creates client on demand - Updated module client to pass factory functions instead of instances - Token is now retrieved from current scope when model is called This maintains full backward compatibility while enabling use in Cog pipelines where tokens are provided through the execution context. * style: fix linter issues - Remove unused *args parameter in test function - Fix formatting issues from linter * fix: resolve async detection and test issues - Fix async detection to not call client factory prematurely - Add use_async parameter to explicitly indicate async mode - Update test to avoid creating client during verification - Fix test mocking to use correct module path * test: simplify lazy client test Replace complex mocking test with simpler verification that: - use() works without token initially - Lazy client factory is properly configured - Client can be created when needed This avoids complex mocking while still verifying the core functionality. * lint * fix: add type ignore for final linter warning * fix: add arg-type ignore for type checker warnings * refactor: simplify lazy client creation to use Type[Client] only Address PR feedback by removing Union types and using a single consistent approach: - Change Function/AsyncFunction constructors to accept Type[Client] only - Remove Union[Client, Type[Client]] in favor of just Type[Client] - Simplify _client property logic by removing isinstance checks - Update all use() overloads to accept class types only - Use issubclass() for async client detection instead of complex logic - Update tests to check for _client_class attribute This maintains the same lazy client creation behavior while being much simpler and more consistent. * Update tests/test_simple_lazy.py Co-authored-by: Aron Carroll <[email protected]> * test: improve lazy client test to follow project conventions - Remove verbose comments and print statements - Focus on observable behavior rather than internal implementation - Use proper mocking that matches actual cog integration - Test that cog.current_scope() is called on client creation - Address code review feedback from PR discussion * lint * lint --------- Co-authored-by: Aron Carroll <[email protected]>
1 parent af6ae70 commit caf4c4e

File tree

3 files changed

+78
-15
lines changed

3 files changed

+78
-15
lines changed

src/replicate/_module_client.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,17 @@ def _run(*args, **kwargs):
8888
return _load_client().run(*args, **kwargs)
8989

9090
def _use(ref, *, hint=None, streaming=False, use_async=False, **kwargs):
91+
from .lib._predictions_use import use
92+
9193
if use_async:
9294
# For async, we need to use AsyncReplicate instead
9395
from ._client import AsyncReplicate
9496

95-
client = AsyncReplicate()
96-
return client.use(ref, hint=hint, streaming=streaming, **kwargs)
97-
return _load_client().use(ref, hint=hint, streaming=streaming, **kwargs)
97+
return use(AsyncReplicate, ref, hint=hint, streaming=streaming, **kwargs)
98+
99+
from ._client import Replicate
100+
101+
return use(Replicate, ref, hint=hint, streaming=streaming, **kwargs)
98102

99103
run = _run
100104
use = _use

src/replicate/lib/_predictions_use.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Any,
1010
Dict,
1111
List,
12+
Type,
1213
Tuple,
1314
Union,
1415
Generic,
@@ -436,15 +437,18 @@ class Function(Generic[Input, Output]):
436437
A wrapper for a Replicate model that can be called as a function.
437438
"""
438439

439-
_client: Client
440440
_ref: str
441441
_streaming: bool
442442

443-
def __init__(self, client: Client, ref: str, *, streaming: bool) -> None:
444-
self._client = client
443+
def __init__(self, client: Type[Client], ref: str, *, streaming: bool) -> None:
444+
self._client_class = client
445445
self._ref = ref
446446
self._streaming = streaming
447447

448+
@property
449+
def _client(self) -> Client:
450+
return self._client_class()
451+
448452
def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output:
449453
return self.create(*args, **inputs).output()
450454

@@ -666,16 +670,19 @@ class AsyncFunction(Generic[Input, Output]):
666670
An async wrapper for a Replicate model that can be called as a function.
667671
"""
668672

669-
_client: AsyncClient
670673
_ref: str
671674
_streaming: bool
672675
_openapi_schema: Optional[Dict[str, Any]] = None
673676

674-
def __init__(self, client: AsyncClient, ref: str, *, streaming: bool) -> None:
675-
self._client = client
677+
def __init__(self, client: Type[AsyncClient], ref: str, *, streaming: bool) -> None:
678+
self._client_class = client
676679
self._ref = ref
677680
self._streaming = streaming
678681

682+
@property
683+
def _client(self) -> AsyncClient:
684+
return self._client_class()
685+
679686
@cached_property
680687
def _parsed_ref(self) -> Tuple[str, str, Optional[str]]:
681688
return ModelVersionIdentifier.parse(self._ref)
@@ -804,7 +811,7 @@ async def openapi_schema(self) -> Dict[str, Any]:
804811

805812
@overload
806813
def use(
807-
client: Client,
814+
client: Type[Client],
808815
ref: Union[str, FunctionRef[Input, Output]],
809816
*,
810817
hint: Optional[Callable[Input, Output]] = None,
@@ -814,7 +821,7 @@ def use(
814821

815822
@overload
816823
def use(
817-
client: Client,
824+
client: Type[Client],
818825
ref: Union[str, FunctionRef[Input, Output]],
819826
*,
820827
hint: Optional[Callable[Input, Output]] = None,
@@ -824,7 +831,7 @@ def use(
824831

825832
@overload
826833
def use(
827-
client: AsyncClient,
834+
client: Type[AsyncClient],
828835
ref: Union[str, FunctionRef[Input, Output]],
829836
*,
830837
hint: Optional[Callable[Input, Output]] = None,
@@ -834,7 +841,7 @@ def use(
834841

835842
@overload
836843
def use(
837-
client: AsyncClient,
844+
client: Type[AsyncClient],
838845
ref: Union[str, FunctionRef[Input, Output]],
839846
*,
840847
hint: Optional[Callable[Input, Output]] = None,
@@ -843,7 +850,7 @@ def use(
843850

844851

845852
def use(
846-
client: Union[Client, AsyncClient],
853+
client: Union[Type[Client], Type[AsyncClient]],
847854
ref: Union[str, FunctionRef[Input, Output]],
848855
*,
849856
hint: Optional[Callable[Input, Output]] = None, # pylint: disable=unused-argument # noqa: ARG001 # required for type inference
@@ -868,7 +875,7 @@ def use(
868875
except AttributeError:
869876
pass
870877

871-
if isinstance(client, AsyncClient):
878+
if issubclass(client, AsyncClient):
872879
# TODO: Fix type inference for AsyncFunction return type
873880
return AsyncFunction(client, str(ref), streaming=streaming) # type: ignore[return-value]
874881

tests/test_simple_lazy.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""Test lazy client creation in replicate.use()."""
2+
3+
import os
4+
import sys
5+
from unittest.mock import MagicMock, patch
6+
7+
8+
def test_use_does_not_raise_without_token():
9+
"""Test that replicate.use() works even when no API token is available."""
10+
sys.path.insert(0, "src")
11+
12+
with patch.dict(os.environ, {}, clear=True):
13+
with patch.dict(sys.modules, {"cog": None}):
14+
import replicate
15+
16+
# Should not raise an exception
17+
model = replicate.use("test/model") # type: ignore[misc]
18+
assert model is not None
19+
20+
21+
def test_cog_current_scope():
22+
"""Test that cog.current_scope().context is read on each client creation."""
23+
sys.path.insert(0, "src")
24+
25+
mock_context = MagicMock()
26+
mock_context.items.return_value = [("REPLICATE_API_TOKEN", "test-token-1")]
27+
28+
mock_scope = MagicMock()
29+
mock_scope.context = mock_context
30+
31+
mock_cog = MagicMock()
32+
mock_cog.current_scope.return_value = mock_scope
33+
34+
with patch.dict(os.environ, {}, clear=True):
35+
with patch.dict(sys.modules, {"cog": mock_cog}):
36+
import replicate
37+
38+
model = replicate.use("test/model") # type: ignore[misc]
39+
40+
# Access the client property - this should trigger client creation and cog.current_scope call
41+
_ = model._client
42+
43+
assert mock_cog.current_scope.call_count == 1
44+
45+
# Change the token and access client again - should trigger another call
46+
mock_context.items.return_value = [("REPLICATE_API_TOKEN", "test-token-2")]
47+
48+
# Create a new model to trigger another client creation
49+
model2 = replicate.use("test/model2") # type: ignore[misc]
50+
_ = model2._client
51+
52+
assert mock_cog.current_scope.call_count == 2

0 commit comments

Comments
 (0)