@@ -417,7 +417,13 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
417417 version = self ._version
418418
419419 if version :
420- prediction = self ._client .predictions .create (version = version , input = processed_inputs )
420+ if isinstance (version , VersionGetResponse ):
421+ version_id = version .id
422+ elif isinstance (version , dict ) and "id" in version :
423+ version_id = version ["id" ]
424+ else :
425+ version_id = str (version )
426+ prediction = self ._client .predictions .create (version = version_id , input = processed_inputs )
421427 else :
422428 prediction = self ._client .models .predictions .create (model = self ._model , input = processed_inputs )
423429
@@ -629,7 +635,13 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
629635 version = await self ._version ()
630636
631637 if version :
632- prediction = await self ._client .predictions .create (version = version , input = processed_inputs )
638+ if isinstance (version , VersionGetResponse ):
639+ version_id = version .id
640+ elif isinstance (version , dict ) and "id" in version :
641+ version_id = version ["id" ]
642+ else :
643+ version_id = str (version )
644+ prediction = await self ._client .predictions .create (version = version_id , input = processed_inputs )
633645 else :
634646 model = await self ._model ()
635647 prediction = await self ._client .models .predictions .create (
0 commit comments