33from __future__ import annotations
44
55import os
6- from typing import TYPE_CHECKING , Any , Union , Mapping , Optional
7- from typing_extensions import Self , Unpack , override
6+ from typing import (
7+ TYPE_CHECKING ,
8+ Any ,
9+ Union ,
10+ Literal ,
11+ Mapping ,
12+ TypeVar ,
13+ Callable ,
14+ Iterator ,
15+ Optional ,
16+ AsyncIterator ,
17+ overload ,
18+ )
19+ from typing_extensions import Self , Unpack , ParamSpec , override
820
921import httpx
1022
1123from replicate .lib ._files import FileEncodingStrategy
12- from replicate .lib ._predictions import Model , Version , ModelVersionIdentifier
24+ from replicate .lib ._predictions_run import Model , Version , ModelVersionIdentifier
1325from replicate .types .prediction_create_params import PredictionCreateParamsWithoutVersion
1426
1527from . import _exceptions
4658 from .resources .webhooks .webhooks import WebhooksResource , AsyncWebhooksResource
4759 from .resources .deployments .deployments import DeploymentsResource , AsyncDeploymentsResource
4860
61+ if TYPE_CHECKING :
62+ from .lib ._predictions_use import Function , FunctionRef , AsyncFunction
63+
64+ Input = ParamSpec ("Input" )
65+ Output = TypeVar ("Output" )
66+
4967__all__ = [
5068 "Timeout" ,
5169 "Transport" ,
@@ -236,7 +254,7 @@ def run(
236254 ValueError: If the reference format is invalid
237255 TypeError: If both wait and prefer parameters are provided
238256 """
239- from .lib ._predictions import run
257+ from .lib ._predictions_run import run
240258
241259 return run (
242260 self ,
@@ -247,6 +265,42 @@ def run(
247265 ** params ,
248266 )
249267
268+ @overload
269+ def use (
270+ self ,
271+ ref : Union [str , "FunctionRef[Input, Output]" ],
272+ * ,
273+ hint : Optional [Callable ["Input" , "Output" ]] = None ,
274+ streaming : Literal [False ] = False ,
275+ ) -> "Function[Input, Output]" : ...
276+
277+ @overload
278+ def use (
279+ self ,
280+ ref : Union [str , "FunctionRef[Input, Output]" ],
281+ * ,
282+ hint : Optional [Callable ["Input" , "Output" ]] = None ,
283+ streaming : Literal [True ],
284+ ) -> "Function[Input, Iterator[Output]]" : ...
285+
286+ def use (
287+ self ,
288+ ref : Union [str , "FunctionRef[Input, Output]" ],
289+ * ,
290+ hint : Optional [Callable ["Input" , "Output" ]] = None ,
291+ streaming : bool = False ,
292+ ) -> Union ["Function[Input, Output]" , "Function[Input, Iterator[Output]]" ]:
293+ """
294+ Use a Replicate model as a function.
295+
296+ Example:
297+ flux_dev = replicate.use("black-forest-labs/flux-dev")
298+ output = flux_dev(prompt="make me a sandwich")
299+ """
300+ from .lib ._predictions_use import use as _use
301+
302+ return _use (self , ref , hint = hint , streaming = streaming )
303+
250304 def copy (
251305 self ,
252306 * ,
@@ -510,7 +564,7 @@ async def run(
510564 ValueError: If the reference format is invalid
511565 TypeError: If both wait and prefer parameters are provided
512566 """
513- from .lib ._predictions import async_run
567+ from .lib ._predictions_run import async_run
514568
515569 return await async_run (
516570 self ,
@@ -521,6 +575,42 @@ async def run(
521575 ** params ,
522576 )
523577
578+ @overload
579+ def use (
580+ self ,
581+ ref : Union [str , "FunctionRef[Input, Output]" ],
582+ * ,
583+ hint : Optional [Callable ["Input" , "Output" ]] = None ,
584+ streaming : Literal [False ] = False ,
585+ ) -> "AsyncFunction[Input, Output]" : ...
586+
587+ @overload
588+ def use (
589+ self ,
590+ ref : Union [str , "FunctionRef[Input, Output]" ],
591+ * ,
592+ hint : Optional [Callable ["Input" , "Output" ]] = None ,
593+ streaming : Literal [True ],
594+ ) -> "AsyncFunction[Input, AsyncIterator[Output]]" : ...
595+
596+ def use (
597+ self ,
598+ ref : Union [str , "FunctionRef[Input, Output]" ],
599+ * ,
600+ hint : Optional [Callable ["Input" , "Output" ]] = None ,
601+ streaming : bool = False ,
602+ ) -> Union ["AsyncFunction[Input, Output]" , "AsyncFunction[Input, AsyncIterator[Output]]" ]:
603+ """
604+ Use a Replicate model as an async function.
605+
606+ Example:
607+ flux_dev = replicate.use("black-forest-labs/flux-dev", use_async=True)
608+ output = await flux_dev(prompt="make me a sandwich")
609+ """
610+ from .lib ._predictions_use import use as _use
611+
612+ return _use (self , ref , hint = hint , streaming = streaming )
613+
524614 def copy (
525615 self ,
526616 * ,
0 commit comments