Skip to content

Commit f41cc05

Browse files
committed
Make Run and AsyncRun iterable
1 parent 652b8c7 commit f41cc05

File tree

1 file changed

+104
-31
lines changed

1 file changed

+104
-31
lines changed

src/replicate/lib/_predictions_use.py

Lines changed: 104 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -197,22 +197,19 @@ def _resolve_ref(obj: Any) -> Any:
197197
T = TypeVar("T")
198198

199199

200-
class OutputIterator(Generic[T]):
200+
class SyncOutputIterator(Generic[T]):
201201
"""
202-
An iterator wrapper that handles both regular iteration and string conversion.
203-
Supports both sync and async iteration patterns.
202+
A synchronous iterator wrapper that handles both regular iteration and string conversion.
204203
"""
205204

206205
def __init__(
207206
self,
208207
iterator_factory: Callable[[], Iterator[T]],
209-
async_iterator_factory: Callable[[], AsyncIterator[T]],
210208
schema: Dict[str, Any],
211209
*,
212210
is_concatenate: bool,
213211
) -> None:
214212
self.iterator_factory = iterator_factory
215-
self.async_iterator_factory = async_iterator_factory
216213
self.schema = schema
217214
self.is_concatenate = is_concatenate
218215

@@ -224,6 +221,30 @@ def __iter__(self) -> Iterator[T]:
224221
else:
225222
yield _process_iterator_item(chunk, self.schema)
226223

224+
@override
225+
def __str__(self) -> str:
    """
    Render the output as text.

    The segments come from a fresh call to ``iterator_factory``. In concatenate
    mode each segment is stringified and the pieces are joined with an empty
    separator; otherwise the segments are collected into a list and that list's
    ``str()`` form is returned.
    """
    segments = self.iterator_factory()
    if not self.is_concatenate:
        return str(list(segments))
    return "".join(map(str, segments))
230+
231+
232+
class AsyncOutputIterator(Generic[T]):
233+
"""
234+
An asynchronous iterator wrapper that handles both regular iteration and string conversion.
235+
"""
236+
237+
def __init__(
    self,
    async_iterator_factory: Callable[[], AsyncIterator[T]],
    schema: Dict[str, Any],
    *,
    is_concatenate: bool,
) -> None:
    """
    Store what is needed to iterate a prediction's output asynchronously.

    ``async_iterator_factory`` is a zero-argument callable returning a fresh
    async iterator of output chunks, ``schema`` is the schema dict kept for
    processing those chunks, and the keyword-only ``is_concatenate`` flag
    records whether the output is a concatenate-style iterator.
    """
    self.is_concatenate = is_concatenate
    self.schema = schema
    self.async_iterator_factory = async_iterator_factory
247+
227248
async def __aiter__(self) -> AsyncIterator[T]:
228249
"""Iterate over output items asynchronously."""
229250
async for chunk in self.async_iterator_factory():
@@ -232,15 +253,8 @@ async def __aiter__(self) -> AsyncIterator[T]:
232253
else:
233254
yield _process_iterator_item(chunk, self.schema)
234255

235-
@override
236-
def __str__(self) -> str:
237-
"""Convert to string by joining segments with empty string."""
238-
if self.is_concatenate:
239-
return "".join([str(segment) for segment in self.iterator_factory()])
240-
return str(list(self.iterator_factory()))
241-
242256
def __await__(self) -> Generator[Any, None, Union[List[T], str]]:
243-
"""Make OutputIterator awaitable, returning appropriate result based on concatenate mode."""
257+
"""Make AsyncOutputIterator awaitable, returning appropriate result based on concatenate mode."""
244258

245259
async def _collect_result() -> Union[List[T], str]:
246260
if self.is_concatenate:
@@ -329,10 +343,12 @@ class Run(Generic[O]):
329343
Represents a running prediction with access to the underlying schema.
330344
"""
331345

346+
_client: Client
332347
_prediction: Prediction
333348
_schema: Dict[str, Any]
334349

335-
def __init__(self, *, prediction: Prediction, schema: Dict[str, Any], streaming: bool) -> None:
350+
def __init__(self, *, client: Client, prediction: Prediction, schema: Dict[str, Any], streaming: bool) -> None:
351+
self._client = client
336352
self._prediction = prediction
337353
self._schema = schema
338354
self._streaming = streaming
@@ -342,22 +358,21 @@ def output(self) -> O:
342358
Return the output. For iterator types, returns immediately without waiting.
343359
For non-iterator types, waits for completion.
344360
"""
345-
# Return an OutputIterator immediately when streaming, we do this for all
361+
# Return a SyncOutputIterator immediately when streaming, we do this for all
346362
# model return types regardless of whether they return an iterator.
347363
if self._streaming:
348364
is_concatenate = _has_concatenate_iterator_output_type(self._schema)
349365
return cast(
350366
O,
351-
OutputIterator(
352-
self._prediction.output_iterator,
353-
self._prediction.async_output_iterator,
367+
SyncOutputIterator(
368+
self._output_iterator,
354369
self._schema,
355370
is_concatenate=is_concatenate,
356371
),
357372
)
358373

359374
# For non-streaming, wait for completion and process output
360-
self._prediction.wait()
375+
self._prediction = self._client.predictions.wait(prediction_id=self._prediction.id)
361376

362377
if self._prediction.status == "failed":
363378
raise ModelError(self._prediction)
@@ -375,10 +390,36 @@ def logs(self) -> Optional[str]:
375390
"""
376391
Fetch and return the logs from the prediction.
377392
"""
378-
self._prediction.reload()
393+
self._prediction = self._client.predictions.get(prediction_id=self._prediction.id)
379394

380395
return self._prediction.logs
381396

397+
def _output_iterator(self) -> Iterator[Any]:
    """
    Return an iterator of the prediction output.

    Yields the items of ``self._prediction.output`` as they become available.
    While the prediction's status is not ``succeeded``, ``failed`` or
    ``canceled``, the prediction is re-fetched with
    ``self._client.predictions.get`` every ``self._client.poll_interval``
    seconds and ``self._prediction`` is replaced with the fetched value.

    Raises:
        ModelError: If the prediction's status is ``failed``. When the
            prediction was already failed on entry, its existing output is
            yielded first; when it fails while polling, the error is raised
            before any remaining output is yielded.
    """
    # Imported once per call instead of on every polling iteration.
    import time

    terminal_statuses = ("succeeded", "failed", "canceled")

    # TODO: check output is list
    # Track how many items have been yielded, starting from zero. Seeding this
    # from the output already present on a still-running prediction would
    # silently drop those items (e.g. when iteration is started a second time).
    emitted = 0

    if self._prediction.status in terminal_statuses:
        output = self._prediction.output or []
        yield from output
        emitted = len(output)

    while self._prediction.status not in terminal_statuses:
        output = self._prediction.output or []
        yield from output[emitted:]
        emitted = len(output)

        time.sleep(self._client.poll_interval)
        self._prediction = self._client.predictions.get(prediction_id=self._prediction.id)

    if self._prediction.status == "failed":
        raise ModelError(self._prediction)

    # Flush whatever arrived with the final (terminal) fetch.
    output = self._prediction.output or []
    yield from output[emitted:]
422+
382423

383424
class Function(Generic[Input, Output]):
384425
"""
@@ -401,10 +442,10 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
401442
"""
402443
Start a prediction with the specified inputs.
403444
"""
404-
# Process inputs to convert concatenate OutputIterators to strings and URLPath to URLs
445+
# Process inputs to convert concatenate SyncOutputIterators to strings and URLPath to URLs
405446
processed_inputs = {}
406447
for key, value in inputs.items():
407-
if isinstance(value, OutputIterator):
448+
if isinstance(value, SyncOutputIterator):
408449
if value.is_concatenate:
409450
processed_inputs[key] = str(value)
410451
else:
@@ -428,6 +469,7 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
428469
prediction = self._client.models.predictions.create(model=self._model, input=processed_inputs)
429470

430471
return Run(
472+
client=self._client,
431473
prediction=prediction,
432474
schema=self.openapi_schema(),
433475
streaming=self._streaming,
@@ -511,10 +553,12 @@ class AsyncRun(Generic[O]):
511553
Represents a running prediction with access to its version (async version).
512554
"""
513555

556+
_client: AsyncClient
514557
_prediction: Prediction
515558
_schema: Dict[str, Any]
516559

517-
def __init__(self, *, prediction: Prediction, schema: Dict[str, Any], streaming: bool) -> None:
560+
def __init__(self, *, client: AsyncClient, prediction: Prediction, schema: Dict[str, Any], streaming: bool) -> None:
561+
self._client = client
518562
self._prediction = prediction
519563
self._schema = schema
520564
self._streaming = streaming
@@ -524,22 +568,21 @@ async def output(self) -> O:
524568
Return the output. For iterator types, returns immediately without waiting.
525569
For non-iterator types, waits for completion.
526570
"""
527-
# Return an OutputIterator immediately when streaming, we do this for all
571+
# Return an AsyncOutputIterator immediately when streaming, we do this for all
528572
# model return types regardless of whether they return an iterator.
529573
if self._streaming:
530574
is_concatenate = _has_concatenate_iterator_output_type(self._schema)
531575
return cast(
532576
O,
533-
OutputIterator(
534-
self._prediction.output_iterator,
535-
self._prediction.async_output_iterator,
577+
AsyncOutputIterator(
578+
self._async_output_iterator,
536579
self._schema,
537580
is_concatenate=is_concatenate,
538581
),
539582
)
540583

541584
# For non-streaming, wait for completion and process output
542-
await self._prediction.async_wait()
585+
self._prediction = await self._client.predictions.wait(prediction_id=self._prediction.id)
543586

544587
if self._prediction.status == "failed":
545588
raise ModelError(self._prediction)
@@ -557,10 +600,39 @@ async def logs(self) -> Optional[str]:
557600
"""
558601
Fetch and return the logs from the prediction asynchronously.
559602
"""
560-
await self._prediction.async_reload()
603+
self._prediction = await self._client.predictions.async_get(prediction_id=self._prediction.id)
561604

562605
return self._prediction.logs
563606

607+
async def _async_output_iterator(self) -> AsyncIterator[Any]:
    """
    Return an asynchronous iterator of the prediction output.

    Yields the items of ``self._prediction.output`` as they become available.
    While the prediction's status is not ``succeeded``, ``failed`` or
    ``canceled``, the prediction is re-fetched every
    ``self._client.poll_interval`` seconds (without blocking the event loop)
    and ``self._prediction`` is replaced with the fetched value.

    Raises:
        ModelError: If the prediction's status is ``failed``. When the
            prediction was already failed on entry, its existing output is
            yielded first; when it fails while polling, the error is raised
            before any remaining output is yielded.
    """
    # Imported once per call instead of on every polling iteration.
    import asyncio

    terminal_statuses = ("succeeded", "failed", "canceled")

    # TODO: check output is list
    # Track how many items have been yielded, starting from zero. Seeding this
    # from the output already present on a still-running prediction would
    # silently drop those items (e.g. when iteration is started a second time).
    emitted = 0

    if self._prediction.status in terminal_statuses:
        output = self._prediction.output or []
        for item in output:
            yield item
        emitted = len(output)

    while self._prediction.status not in terminal_statuses:
        output = self._prediction.output or []
        for item in output[emitted:]:
            yield item
        emitted = len(output)

        await asyncio.sleep(self._client.poll_interval)
        # NOTE(review): output() on this class awaits `predictions.wait`, while this
        # call uses `predictions.async_get` - confirm the async client exposes it.
        self._prediction = await self._client.predictions.async_get(prediction_id=self._prediction.id)

    if self._prediction.status == "failed":
        raise ModelError(self._prediction)

    # Flush whatever arrived with the final (terminal) fetch.
    output = self._prediction.output or []
    for item in output[emitted:]:
        yield item
635+
564636

565637
class AsyncFunction(Generic[Input, Output]):
566638
"""
@@ -622,10 +694,10 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
622694
"""
623695
Start a prediction with the specified inputs asynchronously.
624696
"""
625-
# Process inputs to convert concatenate OutputIterators to strings and URLPath to URLs
697+
# Process inputs to convert concatenate AsyncOutputIterators to strings and URLPath to URLs
626698
processed_inputs = {}
627699
for key, value in inputs.items():
628-
if isinstance(value, OutputIterator):
700+
if isinstance(value, AsyncOutputIterator):
629701
processed_inputs[key] = await value
630702
elif url := get_path_url(value):
631703
processed_inputs[key] = url
@@ -649,6 +721,7 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
649721
)
650722

651723
return AsyncRun(
724+
client=self._client,
652725
prediction=prediction,
653726
schema=await self.openapi_schema(),
654727
streaming=self._streaming,

0 commit comments

Comments
 (0)