1010from replicate .types .prediction import Prediction
1111from replicate .types .prediction_create_params import PredictionCreateParamsWithoutVersion
1212
13- from ..types import PredictionOutput , PredictionCreateParams
13+ from ..types import PredictionCreateParams
1414from .._types import NOT_GIVEN , NotGiven
1515from .._utils import is_given
1616from ._models import Model , Version , ModelVersionIdentifier , resolve_reference
@@ -29,7 +29,7 @@ def run(
2929 use_file_output : Optional [bool ] = True ,
3030 file_encoding_strategy : Optional ["FileEncodingStrategy" ] = None ,
3131 ** params : Unpack [PredictionCreateParamsWithoutVersion ],
32- ) -> PredictionOutput | FileOutput | Iterable [FileOutput ] | Dict [str , FileOutput ]:
32+ ) -> object | FileOutput | Iterable [FileOutput ] | Dict [str , FileOutput ]:
3333 from ._files import transform_output
3434
3535 if is_given (wait ) and "prefer" in params :
@@ -91,7 +91,7 @@ def run(
9191
9292 # Return an iterator for the completed prediction when needed.
9393 if version and _has_output_iterator_array_type (version ) and prediction .output is not None :
94- return (transform_output (chunk , client ) for chunk in prediction .output )
94+ return (transform_output (chunk , client ) for chunk in prediction .output ) # type: ignore
9595
9696 if use_file_output :
9797 return transform_output (prediction .output , client ) # type: ignore[no-any-return]
@@ -107,7 +107,7 @@ async def async_run(
107107 wait : Union [int , bool , NotGiven ] = NOT_GIVEN ,
108108 use_file_output : Optional [bool ] = True ,
109109 ** params : Unpack [PredictionCreateParamsWithoutVersion ],
110- ) -> PredictionOutput | FileOutput | Iterable [FileOutput ] | Dict [str , FileOutput ]:
110+ ) -> object | FileOutput | Iterable [FileOutput ] | Dict [str , FileOutput ]:
111111 from ._files import transform_output
112112
113113 if is_given (wait ) and "prefer" in params :
@@ -176,7 +176,7 @@ async def async_run(
176176
177177 # Return an iterator for completed output if the model has an output iterator array type.
178178 if version and _has_output_iterator_array_type (version ) and prediction .output is not None :
179- return (transform_output (chunk , client ) async for chunk in _make_async_iterator (prediction .output ))
179+ return (transform_output (chunk , client ) async for chunk in _make_async_iterator (prediction .output )) # type: ignore
180180 if use_file_output :
181181 return transform_output (prediction .output , client ) # type: ignore[no-any-return]
182182
@@ -207,7 +207,7 @@ def output_iterator(prediction: Prediction, client: Replicate) -> Iterator[Any]:
207207 yield from new_output
208208 previous_output = output
209209 time .sleep (client .poll_interval )
210- prediction = client .predictions .get (prediction .id )
210+ prediction = client .predictions .get (prediction_id = prediction .id )
211211
212212 if prediction .status == "failed" :
213213 raise ModelError (prediction = prediction )
0 commit comments