@@ -436,15 +436,20 @@ class Function(Generic[Input, Output]):
436436 A wrapper for a Replicate model that can be called as a function.
437437 """
438438
439- _client : Client
440439 _ref : str
441440 _streaming : bool
442441
443- def __init__ (self , client : Client , ref : str , * , streaming : bool ) -> None :
444- self ._client = client
442+ def __init__ (self , client : Union [ Client , Callable [[], Client ]] , ref : str , * , streaming : bool ) -> None :
443+ self ._client_or_factory = client
445444 self ._ref = ref
446445 self ._streaming = streaming
447446
447+ @property
448+ def _client (self ) -> Client :
449+ if callable (self ._client_or_factory ):
450+ return self ._client_or_factory ()
451+ return self ._client_or_factory
452+
448453 def __call__ (self , * args : Input .args , ** inputs : Input .kwargs ) -> Output :
449454 return self .create (* args , ** inputs ).output ()
450455
@@ -666,16 +671,21 @@ class AsyncFunction(Generic[Input, Output]):
666671 An async wrapper for a Replicate model that can be called as a function.
667672 """
668673
669- _client : AsyncClient
670674 _ref : str
671675 _streaming : bool
672676 _openapi_schema : Optional [Dict [str , Any ]] = None
673677
674- def __init__ (self , client : AsyncClient , ref : str , * , streaming : bool ) -> None :
675- self ._client = client
678+ def __init__ (self , client : Union [ AsyncClient , Callable [[], AsyncClient ]] , ref : str , * , streaming : bool ) -> None :
679+ self ._client_or_factory = client
676680 self ._ref = ref
677681 self ._streaming = streaming
678682
683+ @property
684+ def _client (self ) -> AsyncClient :
685+ if callable (self ._client_or_factory ):
686+ return self ._client_or_factory ()
687+ return self ._client_or_factory
688+
679689 @cached_property
680690 def _parsed_ref (self ) -> Tuple [str , str , Optional [str ]]:
681691 return ModelVersionIdentifier .parse (self ._ref )
@@ -804,7 +814,7 @@ async def openapi_schema(self) -> Dict[str, Any]:
804814
805815@overload
806816def use (
807- client : Client ,
817+ client : Union [ Client , Callable [[], Client ]] ,
808818 ref : Union [str , FunctionRef [Input , Output ]],
809819 * ,
810820 hint : Optional [Callable [Input , Output ]] = None ,
@@ -814,7 +824,7 @@ def use(
814824
815825@overload
816826def use (
817- client : Client ,
827+ client : Union [ Client , Callable [[], Client ]] ,
818828 ref : Union [str , FunctionRef [Input , Output ]],
819829 * ,
820830 hint : Optional [Callable [Input , Output ]] = None ,
@@ -824,7 +834,7 @@ def use(
824834
825835@overload
826836def use (
827- client : AsyncClient ,
837+ client : Union [ AsyncClient , Callable [[], AsyncClient ]] ,
828838 ref : Union [str , FunctionRef [Input , Output ]],
829839 * ,
830840 hint : Optional [Callable [Input , Output ]] = None ,
@@ -834,7 +844,7 @@ def use(
834844
835845@overload
836846def use (
837- client : AsyncClient ,
847+ client : Union [ AsyncClient , Callable [[], AsyncClient ]] ,
838848 ref : Union [str , FunctionRef [Input , Output ]],
839849 * ,
840850 hint : Optional [Callable [Input , Output ]] = None ,
@@ -843,7 +853,7 @@ def use(
843853
844854
845855def use (
846- client : Union [Client , AsyncClient ],
856+ client : Union [Client , AsyncClient , Callable [[], Client ], Callable [[], AsyncClient ] ],
847857 ref : Union [str , FunctionRef [Input , Output ]],
848858 * ,
849859 hint : Optional [Callable [Input , Output ]] = None , # pylint: disable=unused-argument # noqa: ARG001 # required for type inference
@@ -868,9 +878,14 @@ def use(
868878 except AttributeError :
869879 pass
870880
871- if isinstance (client , AsyncClient ):
881+ # Determine if this is async by checking the type
882+ is_async = isinstance (client , AsyncClient ) or (
883+ callable (client ) and isinstance (client (), AsyncClient )
884+ )
885+
886+ if is_async :
872887 # TODO: Fix type inference for AsyncFunction return type
873888 return AsyncFunction (client , str (ref ), streaming = streaming ) # type: ignore[return-value]
874-
875- # TODO: Fix type inference for Function return type
876- return Function (client , str (ref ), streaming = streaming ) # type: ignore[return-value]
889+ else :
890+ # TODO: Fix type inference for Function return type
891+ return Function (client , str (ref ), streaming = streaming ) # type: ignore[return-value]
0 commit comments