@@ -197,22 +197,19 @@ def _resolve_ref(obj: Any) -> Any:
197197T = TypeVar ("T" )
198198
199199
200- class OutputIterator (Generic [T ]):
200+ class SyncOutputIterator (Generic [T ]):
201201 """
202- An iterator wrapper that handles both regular iteration and string conversion.
203- Supports both sync and async iteration patterns.
202+ A synchronous iterator wrapper that handles both regular iteration and string conversion.
204203 """
205204
206205 def __init__ (
207206 self ,
208207 iterator_factory : Callable [[], Iterator [T ]],
209- async_iterator_factory : Callable [[], AsyncIterator [T ]],
210208 schema : Dict [str , Any ],
211209 * ,
212210 is_concatenate : bool ,
213211 ) -> None :
214212 self .iterator_factory = iterator_factory
215- self .async_iterator_factory = async_iterator_factory
216213 self .schema = schema
217214 self .is_concatenate = is_concatenate
218215
@@ -224,6 +221,30 @@ def __iter__(self) -> Iterator[T]:
224221 else :
225222 yield _process_iterator_item (chunk , self .schema )
226223
224+ @override
225+ def __str__ (self ) -> str :
226+ """Convert to string by joining segments with empty string."""
227+ if self .is_concatenate :
228+ return "" .join ([str (segment ) for segment in self .iterator_factory ()])
229+ return str (list (self .iterator_factory ()))
230+
231+
232+ class AsyncOutputIterator (Generic [T ]):
233+ """
234+ An asynchronous iterator wrapper that handles both regular iteration and string conversion.
235+ """
236+
237+ def __init__ (
238+ self ,
239+ async_iterator_factory : Callable [[], AsyncIterator [T ]],
240+ schema : Dict [str , Any ],
241+ * ,
242+ is_concatenate : bool ,
243+ ) -> None :
244+ self .async_iterator_factory = async_iterator_factory
245+ self .schema = schema
246+ self .is_concatenate = is_concatenate
247+
227248 async def __aiter__ (self ) -> AsyncIterator [T ]:
228249 """Iterate over output items asynchronously."""
229250 async for chunk in self .async_iterator_factory ():
@@ -232,15 +253,8 @@ async def __aiter__(self) -> AsyncIterator[T]:
232253 else :
233254 yield _process_iterator_item (chunk , self .schema )
234255
235- @override
236- def __str__ (self ) -> str :
237- """Convert to string by joining segments with empty string."""
238- if self .is_concatenate :
239- return "" .join ([str (segment ) for segment in self .iterator_factory ()])
240- return str (list (self .iterator_factory ()))
241-
242256 def __await__ (self ) -> Generator [Any , None , Union [List [T ], str ]]:
243- """Make OutputIterator awaitable, returning appropriate result based on concatenate mode."""
257+ """Make AsyncOutputIterator awaitable, returning appropriate result based on concatenate mode."""
244258
245259 async def _collect_result () -> Union [List [T ], str ]:
246260 if self .is_concatenate :
@@ -329,10 +343,12 @@ class Run(Generic[O]):
329343 Represents a running prediction with access to the underlying schema.
330344 """
331345
346+ _client : Client
332347 _prediction : Prediction
333348 _schema : Dict [str , Any ]
334349
335- def __init__ (self , * , prediction : Prediction , schema : Dict [str , Any ], streaming : bool ) -> None :
350+ def __init__ (self , * , client : Client , prediction : Prediction , schema : Dict [str , Any ], streaming : bool ) -> None :
351+ self ._client = client
336352 self ._prediction = prediction
337353 self ._schema = schema
338354 self ._streaming = streaming
@@ -342,22 +358,21 @@ def output(self) -> O:
342358 Return the output. For iterator types, returns immediately without waiting.
343359 For non-iterator types, waits for completion.
344360 """
345- # Return an OutputIterator immediately when streaming, we do this for all
361+ # Return a SyncOutputIterator immediately when streaming, we do this for all
346362 # model return types regardless of whether they return an iterator.
347363 if self ._streaming :
348364 is_concatenate = _has_concatenate_iterator_output_type (self ._schema )
349365 return cast (
350366 O ,
351- OutputIterator (
352- self ._prediction .output_iterator ,
353- self ._prediction .async_output_iterator ,
367+ SyncOutputIterator (
368+ self ._output_iterator ,
354369 self ._schema ,
355370 is_concatenate = is_concatenate ,
356371 ),
357372 )
358373
359374 # For non-streaming, wait for completion and process output
360- self ._prediction . wait ()
375+ self ._prediction = self . _client . predictions . wait (prediction_id = self . _prediction . id )
361376
362377 if self ._prediction .status == "failed" :
363378 raise ModelError (self ._prediction )
@@ -375,10 +390,36 @@ def logs(self) -> Optional[str]:
375390 """
376391 Fetch and return the logs from the prediction.
377392 """
378- self ._prediction . reload ( )
393+ self ._prediction = self . _client . predictions . get ( prediction_id = self . _prediction . id )
379394
380395 return self ._prediction .logs
381396
397+ def _output_iterator (self ) -> Iterator [Any ]:
398+ """
399+ Return an iterator of the prediction output.
400+ """
401+ if self ._prediction .status in ["succeeded" , "failed" , "canceled" ] and self ._prediction .output is not None :
402+ yield from self ._prediction .output
403+
404+ # TODO: check output is list
405+ previous_output = self ._prediction .output or []
406+ while self ._prediction .status not in ["succeeded" , "failed" , "canceled" ]:
407+ output = self ._prediction .output or []
408+ new_output = output [len (previous_output ) :]
409+ yield from new_output
410+ previous_output = output
411+ import time
412+
413+ time .sleep (self ._client .poll_interval )
414+ self ._prediction = self ._client .predictions .get (prediction_id = self ._prediction .id )
415+
416+ if self ._prediction .status == "failed" :
417+ raise ModelError (self ._prediction )
418+
419+ output = self ._prediction .output or []
420+ new_output = output [len (previous_output ) :]
421+ yield from new_output
422+
382423
383424class Function (Generic [Input , Output ]):
384425 """
@@ -401,10 +442,10 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
401442 """
402443 Start a prediction with the specified inputs.
403444 """
404- # Process inputs to convert concatenate OutputIterators to strings and URLPath to URLs
445+ # Process inputs to convert concatenate SyncOutputIterators to strings and URLPath to URLs
405446 processed_inputs = {}
406447 for key , value in inputs .items ():
407- if isinstance (value , OutputIterator ):
448+ if isinstance (value , SyncOutputIterator ):
408449 if value .is_concatenate :
409450 processed_inputs [key ] = str (value )
410451 else :
@@ -428,6 +469,7 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
428469 prediction = self ._client .models .predictions .create (model = self ._model , input = processed_inputs )
429470
430471 return Run (
472+ client = self ._client ,
431473 prediction = prediction ,
432474 schema = self .openapi_schema (),
433475 streaming = self ._streaming ,
@@ -511,10 +553,12 @@ class AsyncRun(Generic[O]):
511553 Represents a running prediction with access to its version (async version).
512554 """
513555
556+ _client : AsyncClient
514557 _prediction : Prediction
515558 _schema : Dict [str , Any ]
516559
517- def __init__ (self , * , prediction : Prediction , schema : Dict [str , Any ], streaming : bool ) -> None :
560+ def __init__ (self , * , client : AsyncClient , prediction : Prediction , schema : Dict [str , Any ], streaming : bool ) -> None :
561+ self ._client = client
518562 self ._prediction = prediction
519563 self ._schema = schema
520564 self ._streaming = streaming
@@ -524,22 +568,21 @@ async def output(self) -> O:
524568 Return the output. For iterator types, returns immediately without waiting.
525569 For non-iterator types, waits for completion.
526570 """
527- # Return an OutputIterator immediately when streaming, we do this for all
571+ # Return an AsyncOutputIterator immediately when streaming, we do this for all
528572 # model return types regardless of whether they return an iterator.
529573 if self ._streaming :
530574 is_concatenate = _has_concatenate_iterator_output_type (self ._schema )
531575 return cast (
532576 O ,
533- OutputIterator (
534- self ._prediction .output_iterator ,
535- self ._prediction .async_output_iterator ,
577+ AsyncOutputIterator (
578+ self ._async_output_iterator ,
536579 self ._schema ,
537580 is_concatenate = is_concatenate ,
538581 ),
539582 )
540583
541584 # For non-streaming, wait for completion and process output
542- await self ._prediction . async_wait ( )
585+ self . _prediction = await self ._client . predictions . wait ( prediction_id = self . _prediction . id )
543586
544587 if self ._prediction .status == "failed" :
545588 raise ModelError (self ._prediction )
@@ -557,10 +600,39 @@ async def logs(self) -> Optional[str]:
557600 """
558601 Fetch and return the logs from the prediction asynchronously.
559602 """
560- await self ._prediction . async_reload ( )
603+ self . _prediction = await self ._client . predictions . async_get ( prediction_id = self . _prediction . id )
561604
562605 return self ._prediction .logs
563606
607+ async def _async_output_iterator (self ) -> AsyncIterator [Any ]:
608+ """
609+ Return an asynchronous iterator of the prediction output.
610+ """
611+ if self ._prediction .status in ["succeeded" , "failed" , "canceled" ] and self ._prediction .output is not None :
612+ for item in self ._prediction .output :
613+ yield item
614+
615+ # TODO: check output is list
616+ previous_output = self ._prediction .output or []
617+ while self ._prediction .status not in ["succeeded" , "failed" , "canceled" ]:
618+ output = self ._prediction .output or []
619+ new_output = output [len (previous_output ) :]
620+ for item in new_output :
621+ yield item
622+ previous_output = output
623+ import asyncio
624+
625+ await asyncio .sleep (self ._client .poll_interval )
626+ self ._prediction = await self ._client .predictions .async_get (prediction_id = self ._prediction .id )
627+
628+ if self ._prediction .status == "failed" :
629+ raise ModelError (self ._prediction )
630+
631+ output = self ._prediction .output or []
632+ new_output = output [len (previous_output ) :]
633+ for item in new_output :
634+ yield item
635+
564636
565637class AsyncFunction (Generic [Input , Output ]):
566638 """
@@ -622,10 +694,10 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
622694 """
623695 Start a prediction with the specified inputs asynchronously.
624696 """
625- # Process inputs to convert concatenate OutputIterators to strings and URLPath to URLs
697+ # Process inputs to convert concatenate AsyncOutputIterators to strings and URLPath to URLs
626698 processed_inputs = {}
627699 for key , value in inputs .items ():
628- if isinstance (value , OutputIterator ):
700+ if isinstance (value , AsyncOutputIterator ):
629701 processed_inputs [key ] = await value
630702 elif url := get_path_url (value ):
631703 processed_inputs [key ] = url
@@ -649,6 +721,7 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
649721 )
650722
651723 return AsyncRun (
724+ client = self ._client ,
652725 prediction = prediction ,
653726 schema = await self .openapi_schema (),
654727 streaming = self ._streaming ,
0 commit comments