Skip to content

Commit a7b12dd

Browse files
committed
feat: add replicate.use()
1 parent 0d1be44 commit a7b12dd

File tree

3 files changed

+844
-5
lines changed

3 files changed

+844
-5
lines changed

src/replicate/_client.py

Lines changed: 95 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,25 @@
33
from __future__ import annotations
44

55
import os
6-
from typing import TYPE_CHECKING, Any, Union, Mapping, Optional
7-
from typing_extensions import Self, Unpack, override
6+
from typing import (
7+
TYPE_CHECKING,
8+
Any,
9+
Union,
10+
Literal,
11+
Mapping,
12+
TypeVar,
13+
Callable,
14+
Iterator,
15+
Optional,
16+
AsyncIterator,
17+
overload,
18+
)
19+
from typing_extensions import Self, Unpack, ParamSpec, override
820

921
import httpx
1022

1123
from replicate.lib._files import FileEncodingStrategy
12-
from replicate.lib._predictions import Model, Version, ModelVersionIdentifier
24+
from replicate.lib._predictions_run import Model, Version, ModelVersionIdentifier
1325
from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
1426

1527
from . import _exceptions
@@ -46,6 +58,12 @@
4658
from .resources.webhooks.webhooks import WebhooksResource, AsyncWebhooksResource
4759
from .resources.deployments.deployments import DeploymentsResource, AsyncDeploymentsResource
4860

61+
if TYPE_CHECKING:
62+
from .lib._predictions_use import Function, FunctionRef, AsyncFunction
63+
64+
Input = ParamSpec("Input")
65+
Output = TypeVar("Output")
66+
4967
__all__ = [
5068
"Timeout",
5169
"Transport",
@@ -236,7 +254,7 @@ def run(
236254
ValueError: If the reference format is invalid
237255
TypeError: If both wait and prefer parameters are provided
238256
"""
239-
from .lib._predictions import run
257+
from .lib._predictions_run import run
240258

241259
return run(
242260
self,
@@ -247,6 +265,42 @@ def run(
247265
**params,
248266
)
249267

268+
@overload
269+
def use(
270+
self,
271+
ref: Union[str, "FunctionRef[Input, Output]"],
272+
*,
273+
hint: Optional[Callable["Input", "Output"]] = None,
274+
streaming: Literal[False] = False,
275+
) -> "Function[Input, Output]": ...
276+
277+
@overload
278+
def use(
279+
self,
280+
ref: Union[str, "FunctionRef[Input, Output]"],
281+
*,
282+
hint: Optional[Callable["Input", "Output"]] = None,
283+
streaming: Literal[True],
284+
) -> "Function[Input, Iterator[Output]]": ...
285+
286+
def use(
287+
self,
288+
ref: Union[str, "FunctionRef[Input, Output]"],
289+
*,
290+
hint: Optional[Callable["Input", "Output"]] = None,
291+
streaming: bool = False,
292+
) -> Union["Function[Input, Output]", "Function[Input, Iterator[Output]]"]:
293+
"""
294+
Use a Replicate model as a function.
295+
296+
Example:
297+
flux_dev = replicate.use("black-forest-labs/flux-dev")
298+
output = flux_dev(prompt="make me a sandwich")
299+
"""
300+
from .lib._predictions_use import use as _use
301+
302+
return _use(self, ref, hint=hint, streaming=streaming)
303+
250304
def copy(
251305
self,
252306
*,
@@ -510,7 +564,7 @@ async def run(
510564
ValueError: If the reference format is invalid
511565
TypeError: If both wait and prefer parameters are provided
512566
"""
513-
from .lib._predictions import async_run
567+
from .lib._predictions_run import async_run
514568

515569
return await async_run(
516570
self,
@@ -521,6 +575,42 @@ async def run(
521575
**params,
522576
)
523577

578+
@overload
579+
def use(
580+
self,
581+
ref: Union[str, "FunctionRef[Input, Output]"],
582+
*,
583+
hint: Optional[Callable["Input", "Output"]] = None,
584+
streaming: Literal[False] = False,
585+
) -> "AsyncFunction[Input, Output]": ...
586+
587+
@overload
588+
def use(
589+
self,
590+
ref: Union[str, "FunctionRef[Input, Output]"],
591+
*,
592+
hint: Optional[Callable["Input", "Output"]] = None,
593+
streaming: Literal[True],
594+
) -> "AsyncFunction[Input, AsyncIterator[Output]]": ...
595+
596+
def use(
597+
self,
598+
ref: Union[str, "FunctionRef[Input, Output]"],
599+
*,
600+
hint: Optional[Callable["Input", "Output"]] = None,
601+
streaming: bool = False,
602+
) -> Union["AsyncFunction[Input, Output]", "AsyncFunction[Input, AsyncIterator[Output]]"]:
603+
"""
604+
Use a Replicate model as an async function.
605+
606+
Example:
607+
flux_dev = replicate.use("black-forest-labs/flux-dev", use_async=True)
608+
output = await flux_dev(prompt="make me a sandwich")
609+
"""
610+
from .lib._predictions_use import use as _use
611+
612+
return _use(self, ref, hint=hint, streaming=streaming)
613+
524614
def copy(
525615
self,
526616
*,
File renamed without changes.

0 commit comments

Comments
 (0)