Skip to content

Commit ab20bbf

Browse files
committed
refactor: simplify lazy client creation to use Type[Client] only
Address PR feedback by removing Union types and using a single consistent approach: - Change Function/AsyncFunction constructors to accept Type[Client] only - Remove Union[Client, Type[Client]] in favor of just Type[Client] - Simplify _client property logic by removing isinstance checks - Update all use() overloads to accept class types only - Use issubclass() for async client detection instead of complex logic - Update tests to check for _client_class attribute This maintains the same lazy client creation behavior while being much simpler and more consistent.
1 parent 0ab8c4f commit ab20bbf

File tree

3 files changed

+26
-31
lines changed

3 files changed

+26
-31
lines changed

src/replicate/_module_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,11 @@ def _use(ref, *, hint=None, streaming=False, use_async=False, **kwargs):
9494
# For async, we need to use AsyncReplicate instead
9595
from ._client import AsyncReplicate
9696

97-
return use(lambda: AsyncReplicate(), ref, hint=hint, streaming=streaming, use_async=True, **kwargs)
97+
return use(AsyncReplicate, ref, hint=hint, streaming=streaming, **kwargs)
9898

99-
return use(_load_client, ref, hint=hint, streaming=streaming, use_async=False, **kwargs)
99+
from ._client import Replicate
100+
101+
return use(Replicate, ref, hint=hint, streaming=streaming, **kwargs)
100102

101103
run = _run
102104
use = _use

src/replicate/lib/_predictions_use.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
Any,
1010
Dict,
1111
List,
12+
Type,
1213
Tuple,
1314
Union,
1415
Generic,
@@ -439,16 +440,14 @@ class Function(Generic[Input, Output]):
439440
_ref: str
440441
_streaming: bool
441442

442-
def __init__(self, client: Union[Client, Callable[[], Client]], ref: str, *, streaming: bool) -> None:
443-
self._client_or_factory = client
443+
def __init__(self, client: Type[Client], ref: str, *, streaming: bool) -> None:
444+
self._client_class = client
444445
self._ref = ref
445446
self._streaming = streaming
446447

447448
@property
448449
def _client(self) -> Client:
449-
if callable(self._client_or_factory):
450-
return self._client_or_factory()
451-
return self._client_or_factory
450+
return self._client_class()
452451

453452
def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output:
454453
return self.create(*args, **inputs).output()
@@ -675,16 +674,14 @@ class AsyncFunction(Generic[Input, Output]):
675674
_streaming: bool
676675
_openapi_schema: Optional[Dict[str, Any]] = None
677676

678-
def __init__(self, client: Union[AsyncClient, Callable[[], AsyncClient]], ref: str, *, streaming: bool) -> None:
679-
self._client_or_factory = client
677+
def __init__(self, client: Type[AsyncClient], ref: str, *, streaming: bool) -> None:
678+
self._client_class = client
680679
self._ref = ref
681680
self._streaming = streaming
682681

683682
@property
684683
def _client(self) -> AsyncClient:
685-
if callable(self._client_or_factory):
686-
return self._client_or_factory()
687-
return self._client_or_factory
684+
return self._client_class()
688685

689686
@cached_property
690687
def _parsed_ref(self) -> Tuple[str, str, Optional[str]]:
@@ -814,7 +811,7 @@ async def openapi_schema(self) -> Dict[str, Any]:
814811

815812
@overload
816813
def use(
817-
client: Union[Client, Callable[[], Client]],
814+
client: Type[Client],
818815
ref: Union[str, FunctionRef[Input, Output]],
819816
*,
820817
hint: Optional[Callable[Input, Output]] = None,
@@ -824,7 +821,7 @@ def use(
824821

825822
@overload
826823
def use(
827-
client: Union[Client, Callable[[], Client]],
824+
client: Type[Client],
828825
ref: Union[str, FunctionRef[Input, Output]],
829826
*,
830827
hint: Optional[Callable[Input, Output]] = None,
@@ -834,7 +831,7 @@ def use(
834831

835832
@overload
836833
def use(
837-
client: Union[AsyncClient, Callable[[], AsyncClient]],
834+
client: Type[AsyncClient],
838835
ref: Union[str, FunctionRef[Input, Output]],
839836
*,
840837
hint: Optional[Callable[Input, Output]] = None,
@@ -844,7 +841,7 @@ def use(
844841

845842
@overload
846843
def use(
847-
client: Union[AsyncClient, Callable[[], AsyncClient]],
844+
client: Type[AsyncClient],
848845
ref: Union[str, FunctionRef[Input, Output]],
849846
*,
850847
hint: Optional[Callable[Input, Output]] = None,
@@ -853,12 +850,11 @@ def use(
853850

854851

855852
def use(
856-
client: Union[Client, AsyncClient, Callable[[], Client], Callable[[], AsyncClient]],
853+
client: Union[Type[Client], Type[AsyncClient]],
857854
ref: Union[str, FunctionRef[Input, Output]],
858855
*,
859856
hint: Optional[Callable[Input, Output]] = None, # pylint: disable=unused-argument # noqa: ARG001 # required for type inference
860857
streaming: bool = False,
861-
use_async: bool = False, # Internal parameter to indicate async mode
862858
) -> Union[
863859
Function[Input, Output],
864860
AsyncFunction[Input, Output],
@@ -879,12 +875,9 @@ def use(
879875
except AttributeError:
880876
pass
881877

882-
# Determine if this is async
883-
is_async = isinstance(client, AsyncClient) or use_async
884-
885-
if is_async:
878+
if issubclass(client, AsyncClient):
886879
# TODO: Fix type inference for AsyncFunction return type
887-
return AsyncFunction(client, str(ref), streaming=streaming) # type: ignore[return-value,arg-type]
888-
else:
889-
# TODO: Fix type inference for Function return type
890-
return Function(client, str(ref), streaming=streaming) # type: ignore[return-value,arg-type]
880+
return AsyncFunction(client, str(ref), streaming=streaming) # type: ignore[return-value]
881+
882+
# TODO: Fix type inference for Function return type
883+
return Function(client, str(ref), streaming=streaming) # type: ignore[return-value]

tests/test_simple_lazy.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ def test_use_does_not_create_client_immediately():
2727

2828
# Verify the client property is a property that will create client on demand
2929
# We can't call it without a token, but we can check it's the right type
30-
assert hasattr(model, "_client_or_factory") # type: ignore[misc]
31-
print("✓ Client factory is stored for lazy creation")
30+
assert hasattr(model, "_client_class") # type: ignore[misc]
31+
print("✓ Client class is stored for lazy creation")
3232

3333
except Exception as e:
3434
print(f"✗ Test failed: {e}")
@@ -55,9 +55,9 @@ def test_client_created_when_model_called():
5555
print("✓ Model function created successfully")
5656

5757
# Verify the model has the lazy client setup
58-
assert hasattr(model, "_client_or_factory")
59-
assert callable(model._client_or_factory)
60-
print("✓ Lazy client factory is properly configured")
58+
assert hasattr(model, "_client_class")
59+
assert isinstance(model._client_class, type)
60+
print("✓ Lazy client class is properly configured")
6161

6262
# Test that accessing _client property works (creates client)
6363
try:

0 commit comments

Comments
 (0)