Skip to content

Commit 80b28fc

Browse files
committed
clean up and use better import
1 parent c7216db commit 80b28fc

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

examples/run_a_model.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import rich
22

3-
from replicate import Replicate
3+
import replicate
44

5-
client = Replicate()
6-
7-
outputs = client.run(
5+
outputs = replicate.run(
86
"black-forest-labs/flux-schnell",
97
input={"prompt": "astronaut riding a rocket like a horse"},
108
)

src/replicate/lib/_predictions.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def run(
9191

9292
# Return an iterator for the completed prediction when needed.
9393
if version and _has_output_iterator_array_type(version) and prediction.output is not None:
94-
return (transform_output(chunk, client) for chunk in prediction.output) # type: ignore
94+
return (transform_output(chunk, client) for chunk in prediction.output) # type: ignore
9595

9696
if use_file_output:
9797
return transform_output(prediction.output, client) # type: ignore[no-any-return]
@@ -176,7 +176,7 @@ async def async_run(
176176

177177
# Return an iterator for completed output if the model has an output iterator array type.
178178
if version and _has_output_iterator_array_type(version) and prediction.output is not None:
179-
return (transform_output(chunk, client) async for chunk in _make_async_iterator(prediction.output)) # type: ignore
179+
return (transform_output(chunk, client) async for chunk in _make_async_iterator(prediction.output)) # type: ignore
180180
if use_file_output:
181181
return transform_output(prediction.output, client) # type: ignore[no-any-return]
182182

@@ -199,10 +199,13 @@ def output_iterator(prediction: Prediction, client: Replicate) -> Iterator[Any]:
199199
Return an iterator of the prediction output.
200200
"""
201201

202-
# TODO: check output is list
203-
previous_output: Any = prediction.output or []
202+
# output can really be anything, but if we hit this then we know
203+
# it should be a list of something!
204+
if not isinstance(prediction.output, list):
205+
raise TypeError(f"Expected prediction output to be a list, got {type(prediction.output)}")
206+
previous_output: list[Any] = prediction.output or [] # type: ignore[union-attr]
204207
while prediction.status not in ["succeeded", "failed", "canceled"]:
205-
output: Any = prediction.output or []
208+
output: list[Any] = prediction.output or [] # type: ignore[union-attr]
206209
new_output = output[len(previous_output) :]
207210
yield from new_output
208211
previous_output = output
@@ -212,6 +215,6 @@ def output_iterator(prediction: Prediction, client: Replicate) -> Iterator[Any]:
212215
if prediction.status == "failed":
213216
raise ModelError(prediction=prediction)
214217

215-
output = prediction.output or []
218+
output: list[Any] = prediction.output or [] # type: ignore[union-attr]
216219
new_output = output[len(previous_output) :]
217220
yield from new_output

0 commit comments

Comments
 (0)