Skip to content

Commit 652b8c7

Browse files
committed
Use VersionGetResponse type for latest_version
1 parent c42f9cf commit 652b8c7

File tree

6 files changed

+25
-7
lines changed

6 files changed

+25
-7
lines changed

src/replicate/lib/_predictions_use.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,13 @@ def create(self, *_: Input.args, **inputs: Input.kwargs) -> Run[Output]:
417417
version = self._version
418418

419419
if version:
420-
prediction = self._client.predictions.create(version=version, input=processed_inputs)
420+
if isinstance(version, VersionGetResponse):
421+
version_id = version.id
422+
elif isinstance(version, dict) and "id" in version:
423+
version_id = version["id"]
424+
else:
425+
version_id = str(version)
426+
prediction = self._client.predictions.create(version=version_id, input=processed_inputs)
421427
else:
422428
prediction = self._client.models.predictions.create(model=self._model, input=processed_inputs)
423429

@@ -629,7 +635,13 @@ async def create(self, *_: Input.args, **inputs: Input.kwargs) -> AsyncRun[Outpu
629635
version = await self._version()
630636

631637
if version:
632-
prediction = await self._client.predictions.create(version=version, input=processed_inputs)
638+
if isinstance(version, VersionGetResponse):
639+
version_id = version.id
640+
elif isinstance(version, dict) and "id" in version:
641+
version_id = version["id"]
642+
else:
643+
version_id = str(version)
644+
prediction = await self._client.predictions.create(version=version_id, input=processed_inputs)
633645
else:
634646
model = await self._model()
635647
prediction = await self._client.models.predictions.create(

src/replicate/types/collection_get_response.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from typing import List, Optional
44
from typing_extensions import Literal
55

6+
from replicate.types.models.version_get_response import VersionGetResponse
7+
68
from .._models import BaseModel
79

810
__all__ = ["CollectionGetResponse", "Model"]
@@ -21,7 +23,7 @@ class Model(BaseModel):
2123
github_url: Optional[str] = None
2224
"""A URL for the model's source code on GitHub"""
2325

24-
latest_version: Optional[object] = None
26+
latest_version: Optional[VersionGetResponse] = None
2527
"""The model's latest version"""
2628

2729
license_url: Optional[str] = None

src/replicate/types/model_create_response.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing_extensions import Literal
55

66
from .._models import BaseModel
7+
from .models.version_get_response import VersionGetResponse
78

89
__all__ = ["ModelCreateResponse"]
910

@@ -21,7 +22,7 @@ class ModelCreateResponse(BaseModel):
2122
github_url: Optional[str] = None
2223
"""A URL for the model's source code on GitHub"""
2324

24-
latest_version: Optional[object] = None
25+
latest_version: Optional[VersionGetResponse] = None
2526
"""The model's latest version"""
2627

2728
license_url: Optional[str] = None

src/replicate/types/model_get_response.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing_extensions import Literal
55

66
from .._models import BaseModel
7+
from .models.version_get_response import VersionGetResponse
78

89
__all__ = ["ModelGetResponse"]
910

@@ -21,7 +22,7 @@ class ModelGetResponse(BaseModel):
2122
github_url: Optional[str] = None
2223
"""A URL for the model's source code on GitHub"""
2324

24-
latest_version: Optional[object] = None
25+
latest_version: Optional[VersionGetResponse] = None
2526
"""The model's latest version"""
2627

2728
license_url: Optional[str] = None

src/replicate/types/model_list_response.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing_extensions import Literal
55

66
from .._models import BaseModel
7+
from .models.version_get_response import VersionGetResponse
78

89
__all__ = ["ModelListResponse"]
910

@@ -21,7 +22,7 @@ class ModelListResponse(BaseModel):
2122
github_url: Optional[str] = None
2223
"""A URL for the model's source code on GitHub"""
2324

24-
latest_version: Optional[object] = None
25+
latest_version: Optional[VersionGetResponse] = None
2526
"""The model's latest version"""
2627

2728
license_url: Optional[str] = None

src/replicate/types/model_search_response.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing_extensions import Literal
55

66
from .._models import BaseModel
7+
from .models.version_get_response import VersionGetResponse
78

89
__all__ = ["ModelSearchResponse"]
910

@@ -21,7 +22,7 @@ class ModelSearchResponse(BaseModel):
2122
github_url: Optional[str] = None
2223
"""A URL for the model's source code on GitHub"""
2324

24-
latest_version: Optional[object] = None
25+
latest_version: Optional[VersionGetResponse] = None
2526
"""The model's latest version"""
2627

2728
license_url: Optional[str] = None

0 commit comments

Comments
 (0)