Skip to content

Commit cfcf1a7

Browse files
dtmeadowsdgellow
authored andcommitted
fix linter issues
1 parent 741df3a commit cfcf1a7

File tree

5 files changed

+25
-14
lines changed

5 files changed

+25
-14
lines changed

src/replicate/_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
import httpx
1010

11+
from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
12+
1113
from . import _exceptions
1214
from ._qs import Querystring
1315
from .types import PredictionOutput, PredictionCreateParams
@@ -132,7 +134,7 @@ def run(
132134
ref: str,
133135
*,
134136
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
135-
**params: Unpack[PredictionCreateParams],
137+
**params: Unpack[PredictionCreateParamsWithoutVersion],
136138
) -> PredictionOutput | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
137139
"""Run a model and wait for its output."""
138140
from .lib._predictions import run

src/replicate/lib/_predictions.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from typing import TYPE_CHECKING, Dict, Union, Iterable
44
from typing_extensions import Unpack
55

6+
from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
7+
68
from ..types import PredictionOutput, PredictionCreateParams
79
from .._types import NOT_GIVEN, NotGiven
810
from .._utils import is_given
@@ -21,7 +23,7 @@ def run(
2123
*,
2224
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
2325
# use_file_output: Optional[bool] = True,
24-
**params: Unpack[PredictionCreateParams],
26+
**params: Unpack[PredictionCreateParamsWithoutVersion],
2527
) -> PredictionOutput | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
2628
from ._files import transform_output
2729

@@ -40,7 +42,8 @@ def run(
4042
params.setdefault("prefer", f"wait={wait}")
4143

4244
# TODO: support more ref types
43-
prediction = client.predictions.create(version=ref, **params)
45+
params_with_version: PredictionCreateParams = {**params, "version": ref}
46+
prediction = client.predictions.create(**params_with_version)
4447

4548
# Currently the "Prefer: wait" interface will return a prediction with a status
4649
# of "processing" rather than a terminal state because it returns before the
@@ -91,7 +94,8 @@ async def async_run(
9194
params.setdefault("prefer", f"wait={wait}")
9295

9396
# TODO: support more ref types
94-
prediction = await client.predictions.create(version=ref, **params)
97+
params_with_version: PredictionCreateParams = {**params, "version": ref}
98+
prediction = await client.predictions.create(**params_with_version)
9599

96100
# Currently the "Prefer: wait" interface will return a prediction with a status
97101
# of "processing" rather than a terminal state because it returns before the

src/replicate/resources/predictions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ def with_streaming_response(self) -> PredictionsResourceWithStreamingResponse:
5050

5151
def wait(self, prediction_id: str) -> Prediction:
5252
"""Wait for prediction to finish."""
53-
prediction = self.retrieve(prediction_id)
53+
prediction = self.get(prediction_id)
5454
while prediction.status not in PREDICTION_TERMINAL_STATES:
5555
self._sleep(self._client.poll_interval)
56-
prediction = self.retrieve(prediction.id)
56+
prediction = self.get(prediction.id)
5757
return prediction
5858

5959
def create(
@@ -469,10 +469,10 @@ def with_streaming_response(self) -> AsyncPredictionsResourceWithStreamingRespon
469469

470470
async def wait(self, prediction_id: str) -> Prediction:
471471
"""Wait for prediction to finish."""
472-
prediction = await self.retrieve(prediction_id)
472+
prediction = await self.get(prediction_id)
473473
while prediction.status not in PREDICTION_TERMINAL_STATES:
474474
await self._sleep(self._client.poll_interval)
475-
prediction = await self.retrieve(prediction.id)
475+
prediction = await self.get(prediction.id)
476476
return prediction
477477

478478
async def create(

src/replicate/types/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
from .deployment_create_params import DeploymentCreateParams as DeploymentCreateParams
1616
from .deployment_list_response import DeploymentListResponse as DeploymentListResponse
1717
from .deployment_update_params import DeploymentUpdateParams as DeploymentUpdateParams
18-
from .prediction_create_params import PredictionCreateParams as PredictionCreateParams
18+
from .prediction_create_params import (
19+
PredictionCreateParams as PredictionCreateParams,
20+
PredictionCreateParamsWithoutVersion as PredictionCreateParamsWithoutVersion,
21+
)
1922
from .training_cancel_response import TrainingCancelResponse as TrainingCancelResponse
2023
from .deployment_create_response import DeploymentCreateResponse as DeploymentCreateResponse
2124
from .deployment_update_response import DeploymentUpdateResponse as DeploymentUpdateResponse

src/replicate/types/prediction_create_params.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88
from .._utils import PropertyInfo
99

10-
__all__ = ["PredictionCreateParams"]
10+
__all__ = ["PredictionCreateParams", "PredictionCreateParamsWithoutVersion"]
1111

1212

13-
class PredictionCreateParams(TypedDict, total=False):
13+
class PredictionCreateParamsWithoutVersion(TypedDict, total=False):
1414
input: Required[object]
1515
"""The model's input as a JSON object.
1616
@@ -36,9 +36,6 @@ class PredictionCreateParams(TypedDict, total=False):
3636
- you don't need to use the file again (Replicate will not store it)
3737
"""
3838

39-
version: Required[str]
40-
"""The ID of the model version that you want to run."""
41-
4239
stream: bool
4340
"""**This field is deprecated.**
4441
@@ -94,3 +91,8 @@ class PredictionCreateParams(TypedDict, total=False):
9491
"""
9592

9693
prefer: Annotated[str, PropertyInfo(alias="Prefer")]
94+
95+
96+
class PredictionCreateParams(PredictionCreateParamsWithoutVersion):
97+
version: Required[str]
98+
"""The ID of the model version that you want to run."""

0 commit comments

Comments
 (0)