Skip to content

Commit b969e71

Browse files
committed
add support for more ref types
1 parent 46b1f8d commit b969e71

File tree

4 files changed

+187
-20
lines changed

4 files changed

+187
-20
lines changed

src/replicate/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
UnprocessableEntityError,
4040
APIResponseValidationError,
4141
)
42+
from .lib._models import Model as Model, Version as Version, ModelVersionIdentifier as ModelVersionIdentifier
4243
from ._base_client import DefaultHttpxClient, DefaultAsyncHttpxClient
4344
from ._utils._logs import setup_logging as _setup_logging
4445

@@ -83,6 +84,9 @@
8384
"DefaultAsyncHttpxClient",
8485
"FileOutput",
8586
"AsyncFileOutput",
87+
"Model",
88+
"Version",
89+
"ModelVersionIdentifier",
8690
]
8791

8892
_setup_logging()

src/replicate/_client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import httpx
1010

11+
from replicate.lib._predictions import Model, Version, ModelVersionIdentifier
1112
from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
1213

1314
from . import _exceptions
@@ -127,7 +128,7 @@ def __init__(
127128

128129
def run(
129130
self,
130-
ref: str,
131+
ref: Union[Model, Version, ModelVersionIdentifier, str],
131132
*,
132133
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
133134
**params: Unpack[PredictionCreateParamsWithoutVersion],
@@ -322,7 +323,7 @@ def __init__(
322323

323324
async def run(
324325
self,
325-
ref: str,
326+
ref: Union[Model, Version, ModelVersionIdentifier, str],
326327
*,
327328
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
328329
**params: Unpack[PredictionCreateParamsWithoutVersion],

src/replicate/lib/_predictions.py

Lines changed: 107 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,57 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Dict, Union, Iterable
3+
from typing import TYPE_CHECKING, Dict, Union, Iterable, Optional
44
from typing_extensions import Unpack
55

66
from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
77

88
from ..types import PredictionOutput, PredictionCreateParams
99
from .._types import NOT_GIVEN, NotGiven
1010
from .._utils import is_given
11-
from .._client import ReplicateClient, AsyncReplicateClient
11+
from ._models import Model, Version, ModelVersionIdentifier, resolve_reference
1212
from .._exceptions import ModelError
1313

1414
if TYPE_CHECKING:
1515
from ._files import FileOutput
16+
from .._client import ReplicateClient, AsyncReplicateClient
1617

1718

1819
def run(
19-
client: ReplicateClient,
20-
ref: str,
21-
# TODO: support these types
22-
# ref: Union["Model", "Version", "ModelVersionIdentifier", str],
20+
client: "ReplicateClient",
21+
ref: Union[Model, Version, ModelVersionIdentifier, str],
2322
*,
2423
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
25-
# use_file_output: Optional[bool] = True,
24+
_use_file_output: Optional[bool] = True,
2625
**params: Unpack[PredictionCreateParamsWithoutVersion],
2726
) -> PredictionOutput | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
27+
"""
28+
Run a model prediction.
29+
30+
Args:
31+
client: The ReplicateClient instance to use for API calls
32+
ref: Reference to the model or version to run. Can be:
33+
- A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
34+
- A string with owner/name format (e.g. "replicate/hello-world")
35+
- A string with owner/name/version format (e.g. "replicate/hello-world/5c7d5dc6...")
36+
- A Model instance with owner and name attributes
37+
- A Version instance with id attribute
38+
- A ModelVersionIdentifier dictionary with owner, name, and/or version keys
39+
input: Dictionary of input parameters for the model
40+
wait: If True (default), wait for the prediction to complete. If False, return immediately.
41+
If an integer, wait up to that many seconds.
42+
use_file_output: If True (default), convert output URLs to FileOutput objects
43+
**params: Additional parameters to pass to the prediction creation endpoint
44+
45+
Returns:
46+
The prediction output, which could be a basic type (str, int, etc.), a FileOutput object,
47+
a list of FileOutput objects, or a dictionary of FileOutput objects, depending on what
48+
the model returns.
49+
50+
Raises:
51+
ModelError: If the model run fails
52+
ValueError: If the reference format is invalid
53+
TypeError: If both wait and prefer parameters are provided
54+
"""
2855
from ._files import transform_output
2956

3057
if is_given(wait) and "prefer" in params:
@@ -41,9 +68,27 @@ def run(
4168
else:
4269
params.setdefault("prefer", f"wait={wait}")
4370

44-
# TODO: support more ref types
45-
params_with_version: PredictionCreateParams = {**params, "version": ref}
46-
prediction = client.predictions.create(**params_with_version)
71+
# Resolve ref to its components
72+
_version, owner, name, version_id = resolve_reference(ref)
73+
74+
prediction = None
75+
if version_id is not None:
76+
# Create prediction with the specific version ID
77+
params_with_version: PredictionCreateParams = {**params, "version": version_id}
78+
prediction = client.predictions.create(**params_with_version)
79+
elif owner and name:
80+
# Create prediction via models resource with owner/name
81+
prediction = client.models.predictions.create(model_owner=owner, model_name=name, **params)
82+
else:
83+
# If ref is a string but doesn't match expected patterns
84+
if isinstance(ref, str):
85+
params_with_version = {**params, "version": ref}
86+
prediction = client.predictions.create(**params_with_version)
87+
else:
88+
raise ValueError(
89+
f"Invalid reference format: {ref}. Expected a model name ('owner/name'), "
90+
"a version ID, a Model object, a Version object, or a ModelVersionIdentifier."
91+
)
4792

4893
# Currently the "Prefer: wait" interface will return a prediction with a status
4994
# of "processing" rather than a terminal state because it returns before the
@@ -68,15 +113,41 @@ def run(
68113

69114

70115
async def async_run(
71-
client: AsyncReplicateClient,
72-
ref: str,
73-
# TODO: support these types
74-
# ref: Union["Model", "Version", "ModelVersionIdentifier", str],
116+
client: "AsyncReplicateClient",
117+
ref: Union[Model, Version, ModelVersionIdentifier, str],
75118
*,
76119
wait: Union[int, bool, NotGiven] = NOT_GIVEN,
77-
# use_file_output: Optional[bool] = True,
120+
_use_file_output: Optional[bool] = True,
78121
**params: Unpack[PredictionCreateParamsWithoutVersion],
79122
) -> PredictionOutput | FileOutput | Iterable[FileOutput] | Dict[str, FileOutput]:
123+
"""
124+
Run a model prediction asynchronously.
125+
126+
Args:
127+
client: The AsyncReplicateClient instance to use for API calls
128+
ref: Reference to the model or version to run. Can be:
129+
- A string containing a version ID (e.g. "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa")
130+
- A string with owner/name format (e.g. "replicate/hello-world")
131+
- A string with owner/name/version format (e.g. "replicate/hello-world/5c7d5dc6...")
132+
- A Model instance with owner and name attributes
133+
- A Version instance with id attribute
134+
- A ModelVersionIdentifier dictionary with owner, name, and/or version keys
135+
input: Dictionary of input parameters for the model
136+
wait: If True (default), wait for the prediction to complete. If False, return immediately.
137+
If an integer, wait up to that many seconds.
138+
use_file_output: If True (default), convert output URLs to AsyncFileOutput objects
139+
**params: Additional parameters to pass to the prediction creation endpoint
140+
141+
Returns:
142+
The prediction output, which could be a basic type (str, int, etc.), an AsyncFileOutput object,
143+
a list of AsyncFileOutput objects, or a dictionary of AsyncFileOutput objects, depending on what
144+
the model returns.
145+
146+
Raises:
147+
ModelError: If the model run fails
148+
ValueError: If the reference format is invalid
149+
TypeError: If both wait and prefer parameters are provided
150+
"""
80151
from ._files import transform_output
81152

82153
if is_given(wait) and "prefer" in params:
@@ -93,9 +164,27 @@ async def async_run(
93164
else:
94165
params.setdefault("prefer", f"wait={wait}")
95166

96-
# TODO: support more ref types
97-
params_with_version: PredictionCreateParams = {**params, "version": ref}
98-
prediction = await client.predictions.create(**params_with_version)
167+
# Resolve ref to its components
168+
_version, owner, name, version_id = resolve_reference(ref)
169+
170+
prediction = None
171+
if version_id is not None:
172+
# Create prediction with the specific version ID
173+
params_with_version: PredictionCreateParams = {**params, "version": version_id}
174+
prediction = await client.predictions.create(**params_with_version)
175+
elif owner and name:
176+
# Create prediction via models resource with owner/name
177+
prediction = await client.models.predictions.create(model_owner=owner, model_name=name, **params)
178+
else:
179+
# If ref is a string but doesn't match expected patterns
180+
if isinstance(ref, str):
181+
params_with_version = {**params, "version": ref}
182+
prediction = await client.predictions.create(**params_with_version)
183+
else:
184+
raise ValueError(
185+
f"Invalid reference format: {ref}. Expected a model name ('owner/name'), "
186+
"a version ID, a Model object, a Version object, or a ModelVersionIdentifier."
187+
)
99188

100189
# Currently the "Prefer: wait" interface will return a prediction with a status
101190
# of "processing" rather than a terminal state because it returns before the

tests/lib/test_run.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from replicate import ReplicateClient, AsyncReplicateClient
1212
from replicate.lib._files import FileOutput, AsyncFileOutput
1313
from replicate._exceptions import ModelError, NotFoundError, BadRequestError
14+
from replicate.lib._models import Model, Version, ModelVersionIdentifier
1415

1516
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
1617
bearer_token = "My Bearer Token"
@@ -154,6 +155,7 @@ def test_run_with_error(self, respx_mock: MockRouter) -> None:
154155
with pytest.raises(ModelError):
155156
self.client.run("error-model-ref", input={"prompt": "trigger error"})
156157

158+
@pytest.mark.skip("todo: support BytesIO conversion")
157159
@pytest.mark.respx(base_url=base_url)
158160
def test_run_with_base64_file(self, respx_mock: MockRouter) -> None:
159161
"""Test run with base64 encoded file input."""
@@ -205,6 +207,41 @@ def test_run_with_invalid_cog_version(self, respx_mock: MockRouter) -> None:
205207
with pytest.raises(BadRequestError):
206208
self.client.run("model-with-invalid-cog", input={"prompt": "test prompt"})
207209

210+
@pytest.mark.respx(base_url=base_url)
211+
def test_run_with_model_object(self, respx_mock: MockRouter) -> None:
212+
"""Test run with Model object reference."""
213+
# Mock the models endpoint for owner/name lookup
214+
respx_mock.post("/models/test-owner/test-model/predictions").mock(
215+
return_value=httpx.Response(201, json=create_mock_prediction())
216+
)
217+
218+
model = Model(owner="test-owner", name="test-model")
219+
output = self.client.run(model, input={"prompt": "test prompt"})
220+
221+
assert output == "test output"
222+
223+
@pytest.mark.respx(base_url=base_url)
224+
def test_run_with_version_object(self, respx_mock: MockRouter) -> None:
225+
"""Test run with Version object reference."""
226+
# Version ID is used directly
227+
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
228+
229+
version = Version(id="test-version-id")
230+
output = self.client.run(version, input={"prompt": "test prompt"})
231+
232+
assert output == "test output"
233+
234+
@pytest.mark.respx(base_url=base_url)
235+
def test_run_with_model_version_identifier(self, respx_mock: MockRouter) -> None:
236+
"""Test run with ModelVersionIdentifier dict reference."""
237+
# Case where version ID is provided
238+
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
239+
240+
identifier: ModelVersionIdentifier = {"owner": "test-owner", "name": "test-model", "version": "test-version-id"}
241+
output = self.client.run(identifier, input={"prompt": "test prompt"})
242+
243+
assert output == "test output"
244+
208245
@pytest.mark.respx(base_url=base_url)
209246
def test_run_with_file_output_iterator(self, respx_mock: MockRouter) -> None:
210247
"""Test run with file output iterator."""
@@ -349,6 +386,7 @@ async def test_async_run_with_error(self, respx_mock: MockRouter) -> None:
349386
with pytest.raises(ModelError):
350387
await self.client.run("error-model-ref", input={"prompt": "trigger error"})
351388

389+
@pytest.mark.skip("todo: support BytesIO conversion")
352390
@pytest.mark.respx(base_url=base_url)
353391
async def test_async_run_with_base64_file(self, respx_mock: MockRouter) -> None:
354392
"""Test async run with base64 encoded file input."""
@@ -400,6 +438,41 @@ async def test_async_run_with_invalid_cog_version(self, respx_mock: MockRouter)
400438
with pytest.raises(BadRequestError):
401439
await self.client.run("model-with-invalid-cog", input={"prompt": "test prompt"})
402440

441+
@pytest.mark.respx(base_url=base_url)
442+
async def test_async_run_with_model_object(self, respx_mock: MockRouter) -> None:
443+
"""Test async run with Model object reference."""
444+
# Mock the models endpoint for owner/name lookup
445+
respx_mock.post("/models/test-owner/test-model/predictions").mock(
446+
return_value=httpx.Response(201, json=create_mock_prediction())
447+
)
448+
449+
model = Model(owner="test-owner", name="test-model")
450+
output = await self.client.run(model, input={"prompt": "test prompt"})
451+
452+
assert output == "test output"
453+
454+
@pytest.mark.respx(base_url=base_url)
455+
async def test_async_run_with_version_object(self, respx_mock: MockRouter) -> None:
456+
"""Test async run with Version object reference."""
457+
# Version ID is used directly
458+
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
459+
460+
version = Version(id="test-version-id")
461+
output = await self.client.run(version, input={"prompt": "test prompt"})
462+
463+
assert output == "test output"
464+
465+
@pytest.mark.respx(base_url=base_url)
466+
async def test_async_run_with_model_version_identifier(self, respx_mock: MockRouter) -> None:
467+
"""Test async run with ModelVersionIdentifier dict reference."""
468+
# Case where version ID is provided
469+
respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction()))
470+
471+
identifier: ModelVersionIdentifier = {"owner": "test-owner", "name": "test-model", "version": "test-version-id"}
472+
output = await self.client.run(identifier, input={"prompt": "test prompt"})
473+
474+
assert output == "test output"
475+
403476
@pytest.mark.respx(base_url=base_url)
404477
async def test_async_run_with_file_output_iterator(self, respx_mock: MockRouter) -> None:
405478
"""Test async run with file output iterator."""

0 commit comments

Comments
 (0)