Skip to content

Commit bcaaff8

Browse files
committed
feat: add backward compatibility for models.get("owner/name") syntax
This PR adds backward compatibility for the legacy models.get("owner/name") syntax while maintaining full forward compatibility with the new keyword argument format. - Add compatibility layer in lib/models.py that handles both formats - Patch both sync and async ModelsResource instances in client initialization - Support both models.get("stability-ai/stable-diffusion") and models.get(model_owner="stability-ai", model_name="stable-diffusion") - Add comprehensive tests for both syntax formats and error cases - Reduce breaking changes from 4 to 3 areas for easier migration Resolves Linear issue DP-656
1 parent 8c05e64 commit bcaaff8

File tree

3 files changed

+363
-2
lines changed

3 files changed

+363
-2
lines changed

src/replicate/_client.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,11 @@ def account(self) -> AccountResource:
186186

187187
@cached_property
188188
def models(self) -> ModelsResource:
189+
from .lib.models import patch_models_resource
189190
from .resources.models import ModelsResource
190191

191-
return ModelsResource(self)
192+
models_resource = ModelsResource(self)
193+
return patch_models_resource(models_resource)
192194

193195
@cached_property
194196
def predictions(self) -> PredictionsResource:
@@ -572,9 +574,11 @@ def account(self) -> AsyncAccountResource:
572574

573575
@cached_property
574576
def models(self) -> AsyncModelsResource:
577+
from .lib.models import patch_models_resource
575578
from .resources.models import AsyncModelsResource
576579

577-
return AsyncModelsResource(self)
580+
models_resource = AsyncModelsResource(self)
581+
return patch_models_resource(models_resource)
578582

579583
@cached_property
580584
def predictions(self) -> AsyncPredictionsResource:

src/replicate/lib/models.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""
2+
Custom models functionality with backward compatibility.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
import inspect
8+
from typing import TYPE_CHECKING, Union
9+
10+
from .._types import NOT_GIVEN, NotGiven
11+
from ._models import ModelVersionIdentifier
12+
13+
if TYPE_CHECKING:
14+
import httpx
15+
16+
from .._types import Body, Query, Headers
17+
from ..resources.models.models import ModelsResource, AsyncModelsResource
18+
from ..types.model_get_response import ModelGetResponse
19+
20+
21+
def _parse_model_args(
22+
model_or_owner: str | NotGiven,
23+
model_owner: str | NotGiven,
24+
model_name: str | NotGiven,
25+
) -> tuple[str, str]:
26+
"""Parse model arguments and return (owner, name)."""
27+
# Handle legacy format: models.get("owner/name")
28+
if model_or_owner is not NOT_GIVEN:
29+
if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN:
30+
raise ValueError(
31+
"Cannot specify both positional 'model_or_owner' and keyword arguments "
32+
"'model_owner'/'model_name'. Use either the legacy format "
33+
"models.get('owner/name') or the new format models.get(model_owner='owner', model_name='name')."
34+
)
35+
36+
# Type guard: model_or_owner is definitely a string here
37+
assert isinstance(model_or_owner, str)
38+
39+
# Parse the owner/name format
40+
if "/" not in model_or_owner:
41+
raise ValueError(
42+
f"Invalid model reference '{model_or_owner}'. "
43+
"Expected format: 'owner/name' (e.g., 'stability-ai/stable-diffusion')"
44+
)
45+
46+
try:
47+
parsed = ModelVersionIdentifier.parse(model_or_owner)
48+
return parsed.owner, parsed.name
49+
except ValueError as e:
50+
raise ValueError(
51+
f"Invalid model reference '{model_or_owner}'. "
52+
f"Expected format: 'owner/name' (e.g., 'stability-ai/stable-diffusion'). "
53+
f"Error: {e}"
54+
) from e
55+
56+
# Validate required parameters for new format
57+
if model_owner is NOT_GIVEN or model_name is NOT_GIVEN:
58+
raise ValueError(
59+
"model_owner and model_name are required. "
60+
"Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
61+
)
62+
63+
return model_owner, model_name
64+
65+
66+
def patch_models_resource(
67+
models_resource: Union["ModelsResource", "AsyncModelsResource"],
68+
) -> Union["ModelsResource", "AsyncModelsResource"]:
69+
"""Patch a models resource to add backward compatibility."""
70+
original_get = models_resource.get
71+
is_async = inspect.iscoroutinefunction(original_get)
72+
73+
if is_async:
74+
75+
async def get_wrapper(
76+
model_or_owner: str | NotGiven = NOT_GIVEN,
77+
*,
78+
model_owner: str | NotGiven = NOT_GIVEN,
79+
model_name: str | NotGiven = NOT_GIVEN,
80+
extra_headers: "Headers | None" = None,
81+
extra_query: "Query | None" = None,
82+
extra_body: "Body | None" = None,
83+
timeout: "float | httpx.Timeout | None | NotGiven" = NOT_GIVEN,
84+
) -> "ModelGetResponse":
85+
owner, name = _parse_model_args(model_or_owner, model_owner, model_name)
86+
return await original_get(
87+
model_owner=owner,
88+
model_name=name,
89+
extra_headers=extra_headers,
90+
extra_query=extra_query,
91+
extra_body=extra_body,
92+
timeout=timeout,
93+
)
94+
else:
95+
96+
def get_wrapper(
97+
model_or_owner: str | NotGiven = NOT_GIVEN,
98+
*,
99+
model_owner: str | NotGiven = NOT_GIVEN,
100+
model_name: str | NotGiven = NOT_GIVEN,
101+
extra_headers: "Headers | None" = None,
102+
extra_query: "Query | None" = None,
103+
extra_body: "Body | None" = None,
104+
timeout: "float | httpx.Timeout | None | NotGiven" = NOT_GIVEN,
105+
) -> "ModelGetResponse":
106+
owner, name = _parse_model_args(model_or_owner, model_owner, model_name)
107+
return original_get(
108+
model_owner=owner,
109+
model_name=name,
110+
extra_headers=extra_headers,
111+
extra_query=extra_query,
112+
extra_body=extra_body,
113+
timeout=timeout,
114+
)
115+
116+
# Store original method for tests and replace with wrapper
117+
models_resource._original_get = original_get
118+
models_resource.get = get_wrapper
119+
return models_resource
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
"""
2+
Tests for backward compatibility in models.get() method.
3+
"""
4+
5+
from unittest.mock import Mock, AsyncMock, patch
6+
7+
import pytest
8+
9+
from replicate import Replicate, AsyncReplicate
10+
from replicate._types import NOT_GIVEN
11+
from replicate.types.model_get_response import ModelGetResponse
12+
13+
14+
@pytest.fixture
15+
def mock_model_response():
16+
"""Mock response for model.get requests."""
17+
return ModelGetResponse(
18+
url="https://replicate.com/stability-ai/stable-diffusion",
19+
owner="stability-ai",
20+
name="stable-diffusion",
21+
description="A model for generating images from text prompts",
22+
visibility="public",
23+
github_url=None,
24+
paper_url=None,
25+
license_url=None,
26+
run_count=0,
27+
cover_image_url=None,
28+
default_example=None,
29+
latest_version=None,
30+
)
31+
32+
33+
class TestModelGetBackwardCompatibility:
34+
"""Test backward compatibility for models.get() method."""
35+
36+
@pytest.fixture
37+
def client(self):
38+
"""Create a Replicate client with mocked token."""
39+
with patch("replicate.lib.cog._get_api_token_from_environment", return_value="test-token"):
40+
return Replicate()
41+
42+
def test_legacy_format_owner_name(self, client, mock_model_response):
43+
"""Test legacy format: models.get('owner/name')."""
44+
# Mock the original get method
45+
client.models._original_get = Mock(return_value=mock_model_response)
46+
47+
# Call with legacy format
48+
result = client.models.get("stability-ai/stable-diffusion")
49+
50+
# Verify original method was called with correct parameters
51+
client.models._original_get.assert_called_once_with(
52+
model_owner="stability-ai",
53+
model_name="stable-diffusion",
54+
extra_headers=None,
55+
extra_query=None,
56+
extra_body=None,
57+
timeout=NOT_GIVEN,
58+
)
59+
assert result == mock_model_response
60+
61+
def test_new_format_keyword_args(self, client, mock_model_response):
62+
"""Test new format: models.get(model_owner='owner', model_name='name')."""
63+
# Mock the original get method
64+
client.models._original_get = Mock(return_value=mock_model_response)
65+
66+
# Call with new format
67+
result = client.models.get(model_owner="stability-ai", model_name="stable-diffusion")
68+
69+
# Verify original method was called with correct parameters
70+
client.models._original_get.assert_called_once_with(
71+
model_owner="stability-ai",
72+
model_name="stable-diffusion",
73+
extra_headers=None,
74+
extra_query=None,
75+
extra_body=None,
76+
timeout=NOT_GIVEN,
77+
)
78+
assert result == mock_model_response
79+
80+
def test_legacy_format_with_extra_params(self, client, mock_model_response):
81+
"""Test legacy format with extra parameters."""
82+
# Mock the original get method
83+
client.models._original_get = Mock(return_value=mock_model_response)
84+
85+
# Call with legacy format and extra parameters
86+
result = client.models.get("stability-ai/stable-diffusion", extra_headers={"X-Custom": "test"}, timeout=30.0)
87+
88+
# Verify original method was called with correct parameters
89+
client.models._original_get.assert_called_once_with(
90+
model_owner="stability-ai",
91+
model_name="stable-diffusion",
92+
extra_headers={"X-Custom": "test"},
93+
extra_query=None,
94+
extra_body=None,
95+
timeout=30.0,
96+
)
97+
assert result == mock_model_response
98+
99+
def test_error_mixed_formats(self, client):
100+
"""Test error when mixing legacy and new formats."""
101+
with pytest.raises(ValueError) as exc_info:
102+
client.models.get("stability-ai/stable-diffusion", model_owner="other-owner")
103+
104+
assert "Cannot specify both positional 'model_or_owner' and keyword arguments" in str(exc_info.value)
105+
106+
def test_error_invalid_legacy_format(self, client):
107+
"""Test error for invalid legacy format (no slash)."""
108+
with pytest.raises(ValueError) as exc_info:
109+
client.models.get("invalid-format")
110+
111+
assert "Invalid model reference 'invalid-format'" in str(exc_info.value)
112+
assert "Expected format: 'owner/name'" in str(exc_info.value)
113+
114+
def test_error_missing_parameters(self, client):
115+
"""Test error when no parameters are provided."""
116+
with pytest.raises(ValueError) as exc_info:
117+
client.models.get()
118+
119+
assert "model_owner and model_name are required" in str(exc_info.value)
120+
121+
def test_legacy_format_with_complex_names(self, client, mock_model_response):
122+
"""Test legacy format with complex owner/model names."""
123+
# Mock the original get method
124+
client.models._original_get = Mock(return_value=mock_model_response)
125+
126+
# Test with hyphenated names and numbers
127+
result = client.models.get("black-forest-labs/flux-1.1-pro")
128+
129+
# Verify parsing
130+
client.models._original_get.assert_called_once_with(
131+
model_owner="black-forest-labs",
132+
model_name="flux-1.1-pro",
133+
extra_headers=None,
134+
extra_query=None,
135+
extra_body=None,
136+
timeout=NOT_GIVEN,
137+
)
138+
139+
def test_legacy_format_multiple_slashes_error(self, client):
140+
"""Test error for legacy format with multiple slashes."""
141+
with pytest.raises(ValueError) as exc_info:
142+
client.models.get("owner/name/version")
143+
144+
assert "Invalid model reference" in str(exc_info.value)
145+
146+
147+
class TestAsyncModelGetBackwardCompatibility:
148+
"""Test backward compatibility for async models.get() method."""
149+
150+
@pytest.fixture
151+
async def async_client(self):
152+
"""Create an async Replicate client with mocked token."""
153+
with patch("replicate.lib.cog._get_api_token_from_environment", return_value="test-token"):
154+
return AsyncReplicate()
155+
156+
@pytest.mark.asyncio
157+
async def test_async_legacy_format_owner_name(self, async_client, mock_model_response):
158+
"""Test async legacy format: models.get('owner/name')."""
159+
# Mock the original async get method
160+
async_client.models._original_get = AsyncMock(return_value=mock_model_response)
161+
162+
# Call with legacy format
163+
result = await async_client.models.get("stability-ai/stable-diffusion")
164+
165+
# Verify original method was called with correct parameters
166+
async_client.models._original_get.assert_called_once_with(
167+
model_owner="stability-ai",
168+
model_name="stable-diffusion",
169+
extra_headers=None,
170+
extra_query=None,
171+
extra_body=None,
172+
timeout=NOT_GIVEN,
173+
)
174+
assert result == mock_model_response
175+
176+
@pytest.mark.asyncio
177+
async def test_async_new_format_keyword_args(self, async_client, mock_model_response):
178+
"""Test async new format: models.get(model_owner='owner', model_name='name')."""
179+
# Mock the original async get method
180+
async_client.models._original_get = AsyncMock(return_value=mock_model_response)
181+
182+
# Call with new format
183+
result = await async_client.models.get(model_owner="stability-ai", model_name="stable-diffusion")
184+
185+
# Verify original method was called with correct parameters
186+
async_client.models._original_get.assert_called_once_with(
187+
model_owner="stability-ai",
188+
model_name="stable-diffusion",
189+
extra_headers=None,
190+
extra_query=None,
191+
extra_body=None,
192+
timeout=NOT_GIVEN,
193+
)
194+
assert result == mock_model_response
195+
196+
@pytest.mark.asyncio
197+
async def test_async_error_mixed_formats(self, async_client):
198+
"""Test async error when mixing legacy and new formats."""
199+
with pytest.raises(ValueError) as exc_info:
200+
await async_client.models.get("stability-ai/stable-diffusion", model_owner="other-owner")
201+
202+
assert "Cannot specify both positional 'model_or_owner' and keyword arguments" in str(exc_info.value)
203+
204+
205+
class TestModelVersionIdentifierIntegration:
206+
"""Test integration with ModelVersionIdentifier parsing."""
207+
208+
@pytest.fixture
209+
def client(self):
210+
"""Create a Replicate client with mocked token."""
211+
with patch("replicate.lib.cog._get_api_token_from_environment", return_value="test-token"):
212+
return Replicate()
213+
214+
def test_legacy_format_parsing_edge_cases(self, client, mock_model_response):
215+
"""Test edge cases in legacy format parsing."""
216+
# Mock the original get method
217+
client.models._original_get = Mock(return_value=mock_model_response)
218+
219+
# Test various valid formats
220+
test_cases = [
221+
("owner/name", "owner", "name"),
222+
("owner-with-hyphens/name-with-hyphens", "owner-with-hyphens", "name-with-hyphens"),
223+
("owner123/name456", "owner123", "name456"),
224+
("owner/name.with.dots", "owner", "name.with.dots"),
225+
]
226+
227+
for model_ref, expected_owner, expected_name in test_cases:
228+
client.models._original_get.reset_mock()
229+
client.models.get(model_ref)
230+
231+
client.models._original_get.assert_called_once_with(
232+
model_owner=expected_owner,
233+
model_name=expected_name,
234+
extra_headers=None,
235+
extra_query=None,
236+
extra_body=None,
237+
timeout=NOT_GIVEN,
238+
)

0 commit comments

Comments
 (0)