99 Any ,
1010 Dict ,
1111 List ,
12+ Type ,
1213 Tuple ,
1314 Union ,
1415 Generic ,
@@ -439,16 +440,14 @@ class Function(Generic[Input, Output]):
439440 _ref : str
440441 _streaming : bool
441442
442- def __init__ (self , client : Union [Client , Callable [[], Client ] ], ref : str , * , streaming : bool ) -> None :
443- self ._client_or_factory = client
443+ def __init__ (self , client : Type [Client ], ref : str , * , streaming : bool ) -> None :
444+ self ._client_class = client
444445 self ._ref = ref
445446 self ._streaming = streaming
446447
447448 @property
448449 def _client (self ) -> Client :
449- if callable (self ._client_or_factory ):
450- return self ._client_or_factory ()
451- return self ._client_or_factory
450+ return self ._client_class ()
452451
453452 def __call__ (self , * args : Input .args , ** inputs : Input .kwargs ) -> Output :
454453 return self .create (* args , ** inputs ).output ()
@@ -675,16 +674,14 @@ class AsyncFunction(Generic[Input, Output]):
675674 _streaming : bool
676675 _openapi_schema : Optional [Dict [str , Any ]] = None
677676
678- def __init__ (self , client : Union [AsyncClient , Callable [[], AsyncClient ] ], ref : str , * , streaming : bool ) -> None :
679- self ._client_or_factory = client
677+ def __init__ (self , client : Type [AsyncClient ], ref : str , * , streaming : bool ) -> None :
678+ self ._client_class = client
680679 self ._ref = ref
681680 self ._streaming = streaming
682681
683682 @property
684683 def _client (self ) -> AsyncClient :
685- if callable (self ._client_or_factory ):
686- return self ._client_or_factory ()
687- return self ._client_or_factory
684+ return self ._client_class ()
688685
689686 @cached_property
690687 def _parsed_ref (self ) -> Tuple [str , str , Optional [str ]]:
@@ -814,7 +811,7 @@ async def openapi_schema(self) -> Dict[str, Any]:
814811
815812@overload
816813def use (
817- client : Union [Client , Callable [[], Client ] ],
814+ client : Type [Client ],
818815 ref : Union [str , FunctionRef [Input , Output ]],
819816 * ,
820817 hint : Optional [Callable [Input , Output ]] = None ,
@@ -824,7 +821,7 @@ def use(
824821
825822@overload
826823def use (
827- client : Union [Client , Callable [[], Client ] ],
824+ client : Type [Client ],
828825 ref : Union [str , FunctionRef [Input , Output ]],
829826 * ,
830827 hint : Optional [Callable [Input , Output ]] = None ,
@@ -834,7 +831,7 @@ def use(
834831
835832@overload
836833def use (
837- client : Union [AsyncClient , Callable [[], AsyncClient ] ],
834+ client : Type [AsyncClient ],
838835 ref : Union [str , FunctionRef [Input , Output ]],
839836 * ,
840837 hint : Optional [Callable [Input , Output ]] = None ,
@@ -844,7 +841,7 @@ def use(
844841
845842@overload
846843def use (
847- client : Union [AsyncClient , Callable [[], AsyncClient ] ],
844+ client : Type [AsyncClient ],
848845 ref : Union [str , FunctionRef [Input , Output ]],
849846 * ,
850847 hint : Optional [Callable [Input , Output ]] = None ,
@@ -853,12 +850,11 @@ def use(
853850
854851
855852def use (
856- client : Union [Client , AsyncClient , Callable [[], Client ], Callable [[], AsyncClient ]],
853+ client : Union [Type [ Client ], Type [ AsyncClient ]],
857854 ref : Union [str , FunctionRef [Input , Output ]],
858855 * ,
859856 hint : Optional [Callable [Input , Output ]] = None , # pylint: disable=unused-argument # noqa: ARG001 # required for type inference
860857 streaming : bool = False ,
861- use_async : bool = False , # Internal parameter to indicate async mode
862858) -> Union [
863859 Function [Input , Output ],
864860 AsyncFunction [Input , Output ],
@@ -879,12 +875,9 @@ def use(
879875 except AttributeError :
880876 pass
881877
882- # Determine if this is async
883- is_async = isinstance (client , AsyncClient ) or use_async
884-
885- if is_async :
878+ if issubclass (client , AsyncClient ):
886879 # TODO: Fix type inference for AsyncFunction return type
887- return AsyncFunction (client , str (ref ), streaming = streaming ) # type: ignore[return-value,arg-type ]
888- else :
889- # TODO: Fix type inference for Function return type
890- return Function (client , str (ref ), streaming = streaming ) # type: ignore[return-value,arg-type ]
880+ return AsyncFunction (client , str (ref ), streaming = streaming ) # type: ignore[return-value]
881+
882+ # TODO: Fix type inference for Function return type
883+ return Function (client , str (ref ), streaming = streaming ) # type: ignore[return-value]
0 commit comments