33from typing import TYPE_CHECKING , Dict , Union , Iterable
44from typing_extensions import Unpack
55
6+ from replicate .types .prediction_create_params import PredictionCreateParamsWithoutVersion
7+
68from ..types import PredictionOutput , PredictionCreateParams
79from .._types import NOT_GIVEN , NotGiven
810from .._utils import is_given
@@ -21,7 +23,7 @@ def run(
2123 * ,
2224 wait : Union [int , bool , NotGiven ] = NOT_GIVEN ,
2325 # use_file_output: Optional[bool] = True,
24- ** params : Unpack [PredictionCreateParams ],
26+ ** params : Unpack [PredictionCreateParamsWithoutVersion ],
2527) -> PredictionOutput | FileOutput | Iterable [FileOutput ] | Dict [str , FileOutput ]:
2628 from ._files import transform_output
2729
@@ -40,7 +42,8 @@ def run(
4042 params .setdefault ("prefer" , f"wait={ wait } " )
4143
4244 # TODO: support more ref types
43- prediction = client .predictions .create (version = ref , ** params )
45+ params_with_version : PredictionCreateParams = {** params , "version" : ref }
46+ prediction = client .predictions .create (** params_with_version )
4447
4548 # Currently the "Prefer: wait" interface will return a prediction with a status
4649 # of "processing" rather than a terminal state because it returns before the
@@ -91,7 +94,8 @@ async def async_run(
9194 params .setdefault ("prefer" , f"wait={ wait } " )
9295
9396 # TODO: support more ref types
94- prediction = await client .predictions .create (version = ref , ** params )
97+ params_with_version : PredictionCreateParams = {** params , "version" : ref }
98+ prediction = await client .predictions .create (** params_with_version )
9599
96100 # Currently the "Prefer: wait" interface will return a prediction with a status
97101 # of "processing" rather than a terminal state because it returns before the
0 commit comments