Skip to content

Commit cf383c6

Browse files
committed
finish last bits of todos
1 parent eadfb53 commit cf383c6

File tree

5 files changed: +124 additions, −24 deletions

src/replicate/_client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def with_raw_response(self) -> ReplicateWithRawResponse:
167167
@cached_property
168168
def with_streaming_response(self) -> ReplicateWithStreamedResponse:
169169
return ReplicateWithStreamedResponse(self)
170-
170+
171171
@cached_property
172172
def poll_interval(self) -> float:
173173
return float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
@@ -191,7 +191,7 @@ def default_headers(self) -> dict[str, str | Omit]:
191191
"X-Stainless-Async": "false",
192192
**self._custom_headers,
193193
}
194-
194+
195195
def run(
196196
self,
197197
ref: Union[Model, Version, ModelVersionIdentifier, str],
@@ -408,7 +408,7 @@ def with_raw_response(self) -> AsyncReplicateWithRawResponse:
408408
@cached_property
409409
def with_streaming_response(self) -> AsyncReplicateWithStreamedResponse:
410410
return AsyncReplicateWithStreamedResponse(self)
411-
411+
412412
@cached_property
413413
def poll_interval(self) -> float:
414414
return float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5"))
@@ -432,7 +432,7 @@ def default_headers(self) -> dict[str, str | Omit]:
432432
"X-Stainless-Async": f"async:{get_async_library()}",
433433
**self._custom_headers,
434434
}
435-
435+
436436
async def run(
437437
self,
438438
ref: Union[Model, Version, ModelVersionIdentifier, str],

src/replicate/_module_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ class PredictionsResourceProxy(LazyProxy["PredictionsResource"]):
6666
def __load__(self) -> PredictionsResource:
6767
return _load_client().predictions
6868

69+
6970
if TYPE_CHECKING:
7071
from ._client import Replicate
7172

@@ -74,6 +75,7 @@ def __load__(self) -> PredictionsResource:
7475
__client: Replicate = cast(Replicate, {})
7576
run = __client.run
7677
else:
78+
7779
def _run(*args, **kwargs):
7880
return _load_client().run(*args, **kwargs)
7981

src/replicate/lib/_models.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Tuple, Union, Optional
3+
from typing import Any, Dict, Tuple, Union, Optional
44
from typing_extensions import TypedDict
55

66

@@ -12,11 +12,27 @@ def __init__(self, owner: str, name: str):
1212
self.name = name
1313

1414

15-
class Version:
16-
"""A specific version of a Replicate model."""
15+
import datetime
1716

18-
def __init__(self, id: str):
19-
self.id = id
17+
from pydantic import BaseModel
18+
19+
20+
class Version(BaseModel):
21+
"""
22+
A version of a model.
23+
"""
24+
25+
id: str
26+
"""The unique ID of the version."""
27+
28+
created_at: datetime.datetime
29+
"""When the version was created."""
30+
31+
cog_version: str
32+
"""The version of the Cog used to create the version."""
33+
34+
openapi_schema: Dict[str, Any]
35+
"""An OpenAPI description of the model inputs and outputs."""
2036

2137

2238
class ModelVersionIdentifier(TypedDict, total=False):
@@ -29,7 +45,7 @@ class ModelVersionIdentifier(TypedDict, total=False):
2945

3046
def resolve_reference(
3147
ref: Union[Model, Version, ModelVersionIdentifier, str],
32-
) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
48+
) -> Tuple[Optional[Version], Optional[str], Optional[str], Optional[str]]:
3349
"""
3450
Resolve a reference to a model or version to its components.
3551

src/replicate/lib/_predictions.py

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Dict, Union, Iterable, Optional
3+
import time
4+
from typing import TYPE_CHECKING, Any, Dict, List, Union, Iterable, Iterator, Optional
5+
from collections.abc import AsyncIterator
46
from typing_extensions import Unpack
57

68
from replicate.lib._files import FileEncodingStrategy
9+
from replicate.lib._schema import make_schema_backwards_compatible
10+
from replicate.types.prediction import Prediction
711
from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
812

913
from ..types import PredictionOutput, PredictionCreateParams
@@ -71,7 +75,7 @@ def run(
7175
params.setdefault("prefer", f"wait={wait}")
7276

7377
# Resolve ref to its components
74-
_version, owner, name, version_id = resolve_reference(ref)
78+
version, owner, name, version_id = resolve_reference(ref)
7579

7680
prediction = None
7781
if version_id is not None:
@@ -104,14 +108,18 @@ def run(
104108
# "processing".
105109
in_terminal_state = is_blocking and prediction.status != "starting"
106110
if not in_terminal_state:
107-
# TODO: Return a "polling" iterator if the model has an output iterator array type.
111+
# Return a "polling" iterator if the model has an output iterator array type.
112+
if version and _has_output_iterator_array_type(version):
113+
return (transform_output(chunk, client) for chunk in output_iterator(prediction=prediction, client=client))
108114

109115
prediction = client.predictions.wait(prediction.id)
110116

111117
if prediction.status == "failed":
112118
raise ModelError(prediction)
113119

114-
# TODO: Return an iterator for completed output if the model has an output iterator array type.
120+
# Return an iterator for the completed prediction when needed.
121+
if version and _has_output_iterator_array_type(version) and prediction.output is not None:
122+
return (transform_output(chunk, client) for chunk in prediction.output)
115123

116124
if use_file_output:
117125
return transform_output(prediction.output, client) # type: ignore[no-any-return]
@@ -173,7 +181,7 @@ async def async_run(
173181
params.setdefault("prefer", f"wait={wait}")
174182

175183
# Resolve ref to its components
176-
_version, owner, name, version_id = resolve_reference(ref)
184+
version, owner, name, version_id = resolve_reference(ref)
177185

178186
prediction = None
179187
if version_id is not None:
@@ -210,16 +218,56 @@ async def async_run(
210218
# "processing".
211219
in_terminal_state = is_blocking and prediction.status != "starting"
212220
if not in_terminal_state:
213-
# TODO: Return a "polling" iterator if the model has an output iterator array type.
221+
# Return a "polling" iterator if the model has an output iterator array type.
222+
# if version and _has_output_iterator_array_type(version):
223+
# return (
224+
# transform_output(chunk, client)
225+
# async for chunk in prediction.async_output_iterator()
226+
# )
214227

215228
prediction = await client.predictions.wait(prediction.id)
216229

217230
if prediction.status == "failed":
218231
raise ModelError(prediction)
219232

220-
# TODO: Return an iterator for completed output if the model has an output iterator array type.
221-
233+
# Return an iterator for completed output if the model has an output iterator array type.
234+
if version and _has_output_iterator_array_type(version) and prediction.output is not None:
235+
return (transform_output(chunk, client) async for chunk in _make_async_iterator(prediction.output))
222236
if use_file_output:
223237
return transform_output(prediction.output, client) # type: ignore[no-any-return]
224238

225239
return prediction.output
240+
241+
242+
def _has_output_iterator_array_type(version: Version) -> bool:
243+
schema = make_schema_backwards_compatible(version.openapi_schema, version.cog_version)
244+
output = schema.get("components", {}).get("schemas", {}).get("Output", {})
245+
return output.get("type") == "array" and output.get("x-cog-array-type") == "iterator" # type: ignore[no-any-return]
246+
247+
248+
async def _make_async_iterator(list: List[Any]) -> AsyncIterator[Any]:
249+
for item in list:
250+
yield item
251+
252+
253+
def output_iterator(prediction: Prediction, client: Replicate) -> Iterator[Any]:
254+
"""
255+
Return an iterator of the prediction output.
256+
"""
257+
258+
# TODO: check output is list
259+
previous_output: Any = prediction.output or []
260+
while prediction.status not in ["succeeded", "failed", "canceled"]:
261+
output: Any = prediction.output or []
262+
new_output = output[len(previous_output) :]
263+
yield from new_output
264+
previous_output = output
265+
time.sleep(client.poll_interval)
266+
prediction = client.predictions.get(prediction.id)
267+
268+
if prediction.status == "failed":
269+
raise ModelError(prediction=prediction)
270+
271+
output = prediction.output or []
272+
new_output = output[len(previous_output) :]
273+
yield from new_output

tests/lib/test_run.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import io
44
import os
5+
import datetime
56
from typing import Any, Dict, Optional
67

78
import httpx
@@ -48,6 +49,41 @@ def create_mock_prediction(
4849
}
4950

5051

52+
def _version_with_schema(id: str = "v1", output_schema: Optional[object] = None) -> Version:
53+
return Version(
54+
id=id,
55+
created_at=datetime.datetime.fromisoformat("2022-03-16T00:35:56.210272"),
56+
cog_version="dev",
57+
openapi_schema={
58+
"openapi": "3.0.2",
59+
"info": {"title": "Cog", "version": "0.1.0"},
60+
"paths": {},
61+
"components": {
62+
"schemas": {
63+
"Input": {
64+
"type": "object",
65+
"title": "Input",
66+
"required": ["text"],
67+
"properties": {
68+
"text": {
69+
"type": "string",
70+
"title": "Text",
71+
"x-order": 0,
72+
"description": "The text input",
73+
},
74+
},
75+
},
76+
"Output": output_schema
77+
or {
78+
"type": "string",
79+
"title": "Output",
80+
},
81+
}
82+
},
83+
},
84+
)
85+
86+
5187
class TestRun:
5288
client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
5389

@@ -227,7 +263,7 @@ def test_run_with_version_object(self, respx_mock: MockRouter) -> None:
227263
# Version ID is used directly
228264
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
229265

230-
version = Version(id="test-version-id")
266+
version = _version_with_schema("test-version-id")
231267
output = self.client.run(version, input={"prompt": "test prompt"})
232268

233269
assert output == "test output"
@@ -243,7 +279,6 @@ def test_run_with_model_version_identifier(self, respx_mock: MockRouter) -> None
243279

244280
assert output == "test output"
245281

246-
@pytest.mark.skip("todo: support file output iterator")
247282
@pytest.mark.respx(base_url=base_url)
248283
def test_run_with_file_output_iterator(self, respx_mock: MockRouter) -> None:
249284
"""Test run with file output iterator."""
@@ -270,7 +305,7 @@ def test_run_with_file_output_iterator(self, respx_mock: MockRouter) -> None:
270305
)
271306

272307
output: list[FileOutput] = self.client.run(
273-
"some-model-ref", use_file_output=True, input={"prompt": "generate file iterator"}
308+
"some-model-ref", use_file_output=True, wait=False, input={"prompt": "generate file iterator"}
274309
)
275310

276311
assert isinstance(output, list)
@@ -460,7 +495,7 @@ async def test_async_run_with_version_object(self, respx_mock: MockRouter) -> No
460495
# Version ID is used directly
461496
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
462497

463-
version = Version(id="test-version-id")
498+
version = _version_with_schema("test-version-id")
464499
output = await self.client.run(version, input={"prompt": "test prompt"})
465500

466501
assert output == "test output"
@@ -476,7 +511,6 @@ async def test_async_run_with_model_version_identifier(self, respx_mock: MockRou
476511

477512
assert output == "test output"
478513

479-
@pytest.mark.skip("todo: support file output iterator")
480514
@pytest.mark.respx(base_url=base_url)
481515
async def test_async_run_with_file_output_iterator(self, respx_mock: MockRouter) -> None:
482516
"""Test async run with file output iterator."""
@@ -503,7 +537,7 @@ async def test_async_run_with_file_output_iterator(self, respx_mock: MockRouter)
503537
)
504538

505539
output: list[AsyncFileOutput] = await self.client.run(
506-
"some-model-ref", use_file_output=True, input={"prompt": "generate file iterator"}
540+
"some-model-ref", use_file_output=True, wait=False, input={"prompt": "generate file iterator"}
507541
)
508542

509543
assert isinstance(output, list)

0 commit comments

Comments (0)