|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from typing import TYPE_CHECKING, Dict, Union, Iterable, Optional |
| 3 | +import time |
| 4 | +from typing import TYPE_CHECKING, Any, Dict, List, Union, Iterable, Iterator, Optional |
| 5 | +from collections.abc import AsyncIterator |
4 | 6 | from typing_extensions import Unpack |
5 | 7 |
|
6 | 8 | from replicate.lib._files import FileEncodingStrategy |
| 9 | +from replicate.lib._schema import make_schema_backwards_compatible |
| 10 | +from replicate.types.prediction import Prediction |
7 | 11 | from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion |
8 | 12 |
|
9 | 13 | from ..types import PredictionOutput, PredictionCreateParams |
@@ -71,7 +75,7 @@ def run( |
71 | 75 | params.setdefault("prefer", f"wait={wait}") |
72 | 76 |
|
73 | 77 | # Resolve ref to its components |
74 | | - _version, owner, name, version_id = resolve_reference(ref) |
| 78 | + version, owner, name, version_id = resolve_reference(ref) |
75 | 79 |
|
76 | 80 | prediction = None |
77 | 81 | if version_id is not None: |
@@ -104,14 +108,18 @@ def run( |
104 | 108 | # "processing". |
105 | 109 | in_terminal_state = is_blocking and prediction.status != "starting" |
106 | 110 | if not in_terminal_state: |
107 | | - # TODO: Return a "polling" iterator if the model has an output iterator array type. |
| 111 | + # Return a "polling" iterator if the model has an output iterator array type. |
| 112 | + if version and _has_output_iterator_array_type(version): |
| 113 | + return (transform_output(chunk, client) for chunk in output_iterator(prediction=prediction, client=client)) |
108 | 114 |
|
109 | 115 | prediction = client.predictions.wait(prediction.id) |
110 | 116 |
|
111 | 117 | if prediction.status == "failed": |
112 | 118 | raise ModelError(prediction) |
113 | 119 |
|
114 | | - # TODO: Return an iterator for completed output if the model has an output iterator array type. |
| 120 | + # Return an iterator for the completed prediction when needed. |
| 121 | + if version and _has_output_iterator_array_type(version) and prediction.output is not None: |
| 122 | + return (transform_output(chunk, client) for chunk in prediction.output) |
115 | 123 |
|
116 | 124 | if use_file_output: |
117 | 125 | return transform_output(prediction.output, client) # type: ignore[no-any-return] |
@@ -173,7 +181,7 @@ async def async_run( |
173 | 181 | params.setdefault("prefer", f"wait={wait}") |
174 | 182 |
|
175 | 183 | # Resolve ref to its components |
176 | | - _version, owner, name, version_id = resolve_reference(ref) |
| 184 | + version, owner, name, version_id = resolve_reference(ref) |
177 | 185 |
|
178 | 186 | prediction = None |
179 | 187 | if version_id is not None: |
@@ -210,16 +218,56 @@ async def async_run( |
210 | 218 | # "processing". |
211 | 219 | in_terminal_state = is_blocking and prediction.status != "starting" |
212 | 220 | if not in_terminal_state: |
213 | | - # TODO: Return a "polling" iterator if the model has an output iterator array type. |
| 221 | + # Return a "polling" iterator if the model has an output iterator array type. |
| 222 | + # if version and _has_output_iterator_array_type(version): |
| 223 | + # return ( |
| 224 | + # transform_output(chunk, client) |
| 225 | + # async for chunk in prediction.async_output_iterator() |
| 226 | + # ) |
214 | 227 |
|
215 | 228 | prediction = await client.predictions.wait(prediction.id) |
216 | 229 |
|
217 | 230 | if prediction.status == "failed": |
218 | 231 | raise ModelError(prediction) |
219 | 232 |
|
220 | | - # TODO: Return an iterator for completed output if the model has an output iterator array type. |
221 | | - |
| 233 | + # Return an iterator for completed output if the model has an output iterator array type. |
| 234 | + if version and _has_output_iterator_array_type(version) and prediction.output is not None: |
| 235 | + return (transform_output(chunk, client) async for chunk in _make_async_iterator(prediction.output)) |
222 | 236 | if use_file_output: |
223 | 237 | return transform_output(prediction.output, client) # type: ignore[no-any-return] |
224 | 238 |
|
225 | 239 | return prediction.output |
| 240 | + |
| 241 | + |
| 242 | +def _has_output_iterator_array_type(version: Version) -> bool: |
| 243 | + schema = make_schema_backwards_compatible(version.openapi_schema, version.cog_version) |
| 244 | + output = schema.get("components", {}).get("schemas", {}).get("Output", {}) |
| 245 | + return output.get("type") == "array" and output.get("x-cog-array-type") == "iterator" # type: ignore[no-any-return] |
| 246 | + |
| 247 | + |
| 248 | +async def _make_async_iterator(list: List[Any]) -> AsyncIterator[Any]: |
| 249 | + for item in list: |
| 250 | + yield item |
| 251 | + |
| 252 | + |
| 253 | +def output_iterator(prediction: Prediction, client: Replicate) -> Iterator[Any]: |
| 254 | + """ |
| 255 | + Return an iterator of the prediction output. |
| 256 | + """ |
| 257 | + |
| 258 | + # TODO: check output is list |
| 259 | + previous_output: Any = prediction.output or [] |
| 260 | + while prediction.status not in ["succeeded", "failed", "canceled"]: |
| 261 | + output: Any = prediction.output or [] |
| 262 | + new_output = output[len(previous_output) :] |
| 263 | + yield from new_output |
| 264 | + previous_output = output |
| 265 | + time.sleep(client.poll_interval) |
| 266 | + prediction = client.predictions.get(prediction.id) |
| 267 | + |
| 268 | + if prediction.status == "failed": |
| 269 | + raise ModelError(prediction=prediction) |
| 270 | + |
| 271 | + output = prediction.output or [] |
| 272 | + new_output = output[len(previous_output) :] |
| 273 | + yield from new_output |
0 commit comments