Skip to content

Commit c7216db

Browse files
committed
clean up helpers to match underlying api changes
1 parent 00970c0 commit c7216db

File tree

4 files changed

+11
-14
lines changed

4 files changed

+11
-14
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@ dependencies = [
1414
"anyio>=3.5.0, <5",
1515
"distro>=1.7.0, <2",
1616
"sniffio",
17-
"asyncio>=3.4.3",
1817
]
1918
requires-python = ">= 3.8"
2019
classifiers = [

src/replicate/lib/_files.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,6 @@
1010

1111
import httpx
1212

13-
from replicate.types.prediction_output import PredictionOutput
14-
1513
from .._utils import is_mapping, is_sequence
1614

1715
# Use TYPE_CHECKING to avoid circular imports
@@ -218,7 +216,7 @@ def __repr__(self) -> str:
218216
return f'{self.__class__.__name__}("{self.url}")'
219217

220218

221-
def transform_output(value: PredictionOutput, client: "Replicate | AsyncReplicate") -> Any:
219+
def transform_output(value: object, client: "Replicate | AsyncReplicate") -> Any:
222220
"""
223221
Transform the output of a prediction to a `FileOutput` object if it's a URL.
224222
"""

src/replicate/lib/_predictions.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from replicate.types.prediction import Prediction
1111
from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
1212

13-
from ..types import PredictionOutput, PredictionCreateParams
13+
from ..types import PredictionCreateParams
1414
from .._types import NOT_GIVEN, NotGiven
1515
from .._utils import is_given
1616
from ._models import Model, Version, ModelVersionIdentifier, resolve_reference
@@ -29,7 +29,7 @@ def run(
2929
use_file_output: Optional[bool] = True,
3030
file_encoding_strategy: Optional["FileEncodingStrategy"] = None,
3131
**params: Unpack[PredictionCreateParamsWithoutVersion],
32-
) -> PredictionOutput | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
32+
) -> object | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
3333
from ._files import transform_output
3434

3535
if is_given(wait) and "prefer" in params:
@@ -91,7 +91,7 @@ def run(
9191

9292
# Return an iterator for the completed prediction when needed.
9393
if version and _has_output_iterator_array_type(version) and prediction.output is not None:
94-
return (transform_output(chunk, client) for chunk in prediction.output)
94+
return (transform_output(chunk, client) for chunk in prediction.output) # type: ignore
9595

9696
if use_file_output:
9797
return transform_output(prediction.output, client) # type: ignore[no-any-return]
@@ -107,7 +107,7 @@ async def async_run(
107107
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
108108
use_file_output: Optional[bool] = True,
109109
**params: Unpack[PredictionCreateParamsWithoutVersion],
110-
) -> PredictionOutput | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
110+
) -> object | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
111111
from ._files import transform_output
112112

113113
if is_given(wait) and "prefer" in params:
@@ -176,7 +176,7 @@ async def async_run(
176176

177177
# Return an iterator for completed output if the model has an output iterator array type.
178178
if version and _has_output_iterator_array_type(version) and prediction.output is not None:
179-
return (transform_output(chunk, client) async for chunk in _make_async_iterator(prediction.output))
179+
return (transform_output(chunk, client) async for chunk in _make_async_iterator(prediction.output)) # type: ignore
180180
if use_file_output:
181181
return transform_output(prediction.output, client) # type: ignore[no-any-return]
182182

@@ -207,7 +207,7 @@ def output_iterator(prediction: Prediction, client: Replicate) -> Iterator[Any]:
207207
yield from new_output
208208
previous_output = output
209209
time.sleep(client.poll_interval)
210-
prediction = client.predictions.get(prediction.id)
210+
prediction = client.predictions.get(prediction_id=prediction.id)
211211

212212
if prediction.status == "failed":
213213
raise ModelError(prediction=prediction)

src/replicate/resources/predictions.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -52,10 +52,10 @@ def with_streaming_response(self) -> PredictionsResourceWithStreamingResponse:
5252

5353
def wait(self, prediction_id: str) -> Prediction:
5454
"""Wait for prediction to finish."""
55-
prediction = self.get(prediction_id)
55+
prediction = self.get(prediction_id=prediction_id)
5656
while prediction.status not in PREDICTION_TERMINAL_STATES:
5757
self._sleep(self._client.poll_interval)
58-
prediction = self.get(prediction.id)
58+
prediction = self.get(prediction_id=prediction.id)
5959
return prediction
6060

6161
def create(
@@ -479,10 +479,10 @@ def with_streaming_response(self) -> AsyncPredictionsResourceWithStreamingRespon
479479

480480
async def wait(self, prediction_id: str) -> Prediction:
481481
"""Wait for prediction to finish."""
482-
prediction = await self.get(prediction_id)
482+
prediction = await self.get(prediction_id=prediction_id)
483483
while prediction.status not in PREDICTION_TERMINAL_STATES:
484484
await self._sleep(self._client.poll_interval)
485-
prediction = await self.get(prediction.id)
485+
prediction = await self.get(prediction_id=prediction.id)
486486
return prediction
487487

488488
async def create(

0 commit comments

Comments (0)