@@ -91,7 +91,7 @@ def run(
9191
9292 # Return an iterator for the completed prediction when needed.
9393 if version and _has_output_iterator_array_type (version ) and prediction .output is not None :
94- return (transform_output (chunk , client ) for chunk in prediction .output ) # type: ignore
94+ return (transform_output (chunk , client ) for chunk in prediction .output ) # type: ignore
9595
9696 if use_file_output :
9797 return transform_output (prediction .output , client ) # type: ignore[no-any-return]
@@ -176,7 +176,7 @@ async def async_run(
176176
177177 # Return an iterator for completed output if the model has an output iterator array type.
178178 if version and _has_output_iterator_array_type (version ) and prediction .output is not None :
179- return (transform_output (chunk , client ) async for chunk in _make_async_iterator (prediction .output )) # type: ignore
179+ return (transform_output (chunk , client ) async for chunk in _make_async_iterator (prediction .output )) # type: ignore
180180 if use_file_output :
181181 return transform_output (prediction .output , client ) # type: ignore[no-any-return]
182182
@@ -199,10 +199,13 @@ def output_iterator(prediction: Prediction, client: Replicate) -> Iterator[Any]:
199199 Return an iterator of the prediction output.
200200 """
201201
202- # TODO: check output is list
203- previous_output : Any = prediction .output or []
202+ # output can really be anything, but if we hit this then we know
203+ # it should be a list of something!
204+ if not isinstance (prediction .output , list ):
205+ raise TypeError (f"Expected prediction output to be a list, got { type (prediction .output )} " )
206+ previous_output : list [Any ] = prediction .output or [] # type: ignore[union-attr]
204207 while prediction .status not in ["succeeded" , "failed" , "canceled" ]:
205- output : Any = prediction .output or []
208+ output : list [ Any ] = prediction .output or [] # type: ignore[union-attr ]
206209 new_output = output [len (previous_output ) :]
207210 yield from new_output
208211 previous_output = output
@@ -212,6 +215,6 @@ def output_iterator(prediction: Prediction, client: Replicate) -> Iterator[Any]:
212215 if prediction .status == "failed" :
213216 raise ModelError (prediction = prediction )
214217
215- output = prediction .output or []
218+ output : list [ Any ] = prediction .output or [] # type: ignore[union-attr ]
216219 new_output = output [len (previous_output ) :]
217220 yield from new_output
0 commit comments