Skip to content

Commit 36a65e6

Browse files
committed
Update run() return type to Any
1 parent cfcf1a7 commit 36a65e6

File tree

3 files changed

+7
-12
lines changed

3 files changed

+7
-12
lines changed

examples/demo.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
from typing import Iterable, cast
2-
31
import rich
42

5-
from replicate import FileOutput, ReplicateClient
3+
from replicate import ReplicateClient
64

75
client = ReplicateClient()
86

@@ -11,6 +9,6 @@
119
input={"prompt": "astronaut riding a rocket like a horse"},
1210
)
1311
rich.print(outputs)
14-
for index, output in enumerate(cast(Iterable[FileOutput], outputs)):
12+
for index, output in enumerate(outputs):
1513
with open(f"output_{index}.webp", "wb") as file:
1614
file.write(output.read())

src/replicate/_client.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import os
6-
from typing import TYPE_CHECKING, Any, Dict, Union, Mapping, Iterable
6+
from typing import TYPE_CHECKING, Any, Union, Mapping
77
from typing_extensions import Self, Unpack, override
88

99
import httpx
@@ -12,7 +12,7 @@
1212

1313
from . import _exceptions
1414
from ._qs import Querystring
15-
from .types import PredictionOutput, PredictionCreateParams
15+
from .types import PredictionCreateParams
1616
from ._types import (
1717
NOT_GIVEN,
1818
Omit,
@@ -36,9 +36,6 @@
3636
from .resources.webhooks import webhooks
3737
from .resources.deployments import deployments
3838

39-
if TYPE_CHECKING:
40-
from .lib._files import FileOutput
41-
4239
__all__ = [
4340
"Timeout",
4441
"Transport",
@@ -135,7 +132,7 @@ def run(
135132
*,
136133
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
137134
**params: Unpack[PredictionCreateParamsWithoutVersion],
138-
) -> PredictionOutput | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
135+
) -> Any:
139136
"""Run a model and wait for its output."""
140137
from .lib._predictions import run
141138

@@ -330,7 +327,7 @@ async def run(
330327
*,
331328
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
332329
**params: Unpack[PredictionCreateParams],
333-
) -> PredictionOutput | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
330+
) -> Any:
334331
"""Run a model and wait for its output."""
335332
from .lib._predictions import async_run
336333

src/replicate/types/prediction_create_params.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,4 @@ class PredictionCreateParamsWithoutVersion(TypedDict, total=False):
9595

9696
class PredictionCreateParams(PredictionCreateParamsWithoutVersion):
9797
version: Required[str]
98-
"""The ID of the model version that you want to run."""
98+
"""The ID of the model version that you want to run."""

0 commit comments

Comments
 (0)