Skip to content

Commit e9d836d

Browse files
committed
feat: add simpler backward compatibility for models.get("owner/name")
This commit implements backward compatibility for the legacy models.get() syntax by directly modifying the generated code instead of using a complex patching system. Changes: - Modified both sync and async get() methods in models.py to accept an optional positional argument for legacy "owner/name" format - Added logic to parse the legacy format and convert to keyword args - Added comprehensive tests for both formats - No runtime patching or method wrapping needed The simpler approach: - 50% less code than the patching approach (~80 lines vs 150+) - Much easier to understand and maintain - No runtime overhead - Same functionality with clearer implementation
1 parent 8c05e64 commit e9d836d

File tree

2 files changed

+238
-10
lines changed

2 files changed

+238
-10
lines changed

src/replicate/resources/models/models.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,10 @@ def delete(
299299

300300
def get(
301301
self,
302+
model_or_owner: str | NotGiven = NOT_GIVEN, # Legacy positional arg
302303
*,
303-
model_owner: str,
304-
model_name: str,
304+
model_owner: str | NotGiven = NOT_GIVEN,
305+
model_name: str | NotGiven = NOT_GIVEN,
305306
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
306307
# The extra values given here take precedence over values defined on the client or passed to this method.
307308
extra_headers: Headers | None = None,
@@ -384,15 +385,39 @@ def get(
384385
The `latest_version` object is the model's most recently pushed
385386
[version](#models.versions.get).
386387
388+
Supports both legacy and new formats:
389+
- Legacy: models.get("owner/name")
390+
- New: models.get(model_owner="owner", model_name="name")
391+
387392
Args:
393+
model_or_owner: Legacy format string "owner/name" (positional argument)
394+
model_owner: Model owner (keyword argument)
395+
model_name: Model name (keyword argument)
388396
extra_headers: Send extra headers
389-
390397
extra_query: Add additional query parameters to the request
391-
392398
extra_body: Add additional JSON properties to the request
393-
394399
timeout: Override the client-level default timeout for this request, in seconds
395400
"""
401+
# Handle legacy format: models.get("owner/name")
402+
if model_or_owner is not NOT_GIVEN:
403+
if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN:
404+
raise ValueError(
405+
"Cannot specify both positional and keyword arguments. "
406+
"Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
407+
)
408+
409+
# Parse the owner/name format
410+
if "/" not in model_or_owner:
411+
raise ValueError(f"Invalid model reference '{model_or_owner}'. Expected format: 'owner/name'")
412+
413+
parts = model_or_owner.split("/", 1)
414+
model_owner = parts[0]
415+
model_name = parts[1]
416+
417+
# Validate required parameters
418+
if model_owner is NOT_GIVEN or model_name is NOT_GIVEN:
419+
raise ValueError("model_owner and model_name are required")
420+
396421
if not model_owner:
397422
raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
398423
if not model_name:
@@ -698,9 +723,10 @@ async def delete(
698723

699724
async def get(
700725
self,
726+
model_or_owner: str | NotGiven = NOT_GIVEN, # Legacy positional arg
701727
*,
702-
model_owner: str,
703-
model_name: str,
728+
model_owner: str | NotGiven = NOT_GIVEN,
729+
model_name: str | NotGiven = NOT_GIVEN,
704730
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
705731
# The extra values given here take precedence over values defined on the client or passed to this method.
706732
extra_headers: Headers | None = None,
@@ -783,15 +809,39 @@ async def get(
783809
The `latest_version` object is the model's most recently pushed
784810
[version](#models.versions.get).
785811
812+
Supports both legacy and new formats:
813+
- Legacy: models.get("owner/name")
814+
- New: models.get(model_owner="owner", model_name="name")
815+
786816
Args:
817+
model_or_owner: Legacy format string "owner/name" (positional argument)
818+
model_owner: Model owner (keyword argument)
819+
model_name: Model name (keyword argument)
787820
extra_headers: Send extra headers
788-
789821
extra_query: Add additional query parameters to the request
790-
791822
extra_body: Add additional JSON properties to the request
792-
793823
timeout: Override the client-level default timeout for this request, in seconds
794824
"""
825+
# Handle legacy format: models.get("owner/name")
826+
if model_or_owner is not NOT_GIVEN:
827+
if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN:
828+
raise ValueError(
829+
"Cannot specify both positional and keyword arguments. "
830+
"Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
831+
)
832+
833+
# Parse the owner/name format
834+
if "/" not in model_or_owner:
835+
raise ValueError(f"Invalid model reference '{model_or_owner}'. Expected format: 'owner/name'")
836+
837+
parts = model_or_owner.split("/", 1)
838+
model_owner = parts[0]
839+
model_name = parts[1]
840+
841+
# Validate required parameters
842+
if model_owner is NOT_GIVEN or model_name is NOT_GIVEN:
843+
raise ValueError("model_owner and model_name are required")
844+
795845
if not model_owner:
796846
raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
797847
if not model_name:
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
"""
2+
Tests for backward compatibility in models.get() method.
3+
"""
4+
5+
from unittest.mock import Mock, patch
6+
7+
import pytest
8+
import httpx
9+
10+
from replicate import Replicate, AsyncReplicate
11+
from replicate._types import NOT_GIVEN
12+
from replicate.types.model_get_response import ModelGetResponse
13+
14+
15+
@pytest.fixture
16+
def mock_model_response():
17+
"""Mock response for model.get requests."""
18+
return ModelGetResponse(
19+
url="https://replicate.com/stability-ai/stable-diffusion",
20+
owner="stability-ai",
21+
name="stable-diffusion",
22+
description="A model for generating images from text prompts",
23+
visibility="public",
24+
github_url=None,
25+
paper_url=None,
26+
license_url=None,
27+
run_count=0,
28+
cover_image_url=None,
29+
default_example=None,
30+
latest_version=None,
31+
)
32+
33+
34+
class TestModelGetBackwardCompatibility:
35+
"""Test backward compatibility for models.get() method."""
36+
37+
@pytest.fixture
38+
def client(self):
39+
"""Create a Replicate client with mocked token."""
40+
return Replicate(bearer_token="test-token")
41+
42+
def test_legacy_format_owner_name(self, client, mock_model_response):
43+
"""Test legacy format: models.get('owner/name')."""
44+
# Mock the underlying _get method
45+
with patch.object(client.models, '_get', return_value=mock_model_response) as mock_get:
46+
# Call with legacy format
47+
result = client.models.get("stability-ai/stable-diffusion")
48+
49+
# Verify underlying method was called with correct parameters
50+
mock_get.assert_called_once_with(
51+
"/models/stability-ai/stable-diffusion",
52+
options=Mock()
53+
)
54+
assert result == mock_model_response
55+
56+
def test_new_format_keyword_args(self, client, mock_model_response):
57+
"""Test new format: models.get(model_owner='owner', model_name='name')."""
58+
# Mock the underlying _get method
59+
with patch.object(client.models, '_get', return_value=mock_model_response) as mock_get:
60+
# Call with new format
61+
result = client.models.get(model_owner="stability-ai", model_name="stable-diffusion")
62+
63+
# Verify underlying method was called with correct parameters
64+
mock_get.assert_called_once_with(
65+
"/models/stability-ai/stable-diffusion",
66+
options=Mock()
67+
)
68+
assert result == mock_model_response
69+
70+
def test_legacy_format_with_extra_params(self, client, mock_model_response):
71+
"""Test legacy format with extra parameters."""
72+
# Mock the underlying _get method
73+
with patch.object(client.models, '_get', return_value=mock_model_response) as mock_get:
74+
# Call with legacy format and extra parameters
75+
result = client.models.get(
76+
"stability-ai/stable-diffusion",
77+
extra_headers={"X-Custom": "test"},
78+
timeout=30.0
79+
)
80+
81+
# Verify underlying method was called
82+
mock_get.assert_called_once()
83+
assert result == mock_model_response
84+
85+
def test_error_mixed_formats(self, client):
86+
"""Test error when mixing legacy and new formats."""
87+
with pytest.raises(ValueError) as exc_info:
88+
client.models.get("stability-ai/stable-diffusion", model_owner="other-owner")
89+
90+
assert "Cannot specify both positional and keyword arguments" in str(exc_info.value)
91+
92+
def test_error_invalid_legacy_format(self, client):
93+
"""Test error for invalid legacy format (no slash)."""
94+
with pytest.raises(ValueError) as exc_info:
95+
client.models.get("invalid-format")
96+
97+
assert "Invalid model reference 'invalid-format'" in str(exc_info.value)
98+
assert "Expected format: 'owner/name'" in str(exc_info.value)
99+
100+
def test_error_missing_parameters(self, client):
101+
"""Test error when no parameters are provided."""
102+
with pytest.raises(ValueError) as exc_info:
103+
client.models.get()
104+
105+
assert "model_owner and model_name are required" in str(exc_info.value)
106+
107+
def test_legacy_format_with_complex_names(self, client, mock_model_response):
108+
"""Test legacy format with complex owner/model names."""
109+
# Mock the underlying _get method
110+
with patch.object(client.models, '_get', return_value=mock_model_response) as mock_get:
111+
# Test with hyphenated names and numbers
112+
result = client.models.get("black-forest-labs/flux-1.1-pro")
113+
114+
# Verify parsing
115+
mock_get.assert_called_once_with(
116+
"/models/black-forest-labs/flux-1.1-pro",
117+
options=Mock()
118+
)
119+
120+
def test_legacy_format_multiple_slashes(self, client):
121+
"""Test legacy format with multiple slashes (should split on first slash only)."""
122+
# Mock the underlying _get method
123+
with patch.object(client.models, '_get', return_value=Mock()) as mock_get:
124+
# This should work - split on first slash only
125+
client.models.get("owner/name/with/slashes")
126+
127+
# Verify it was parsed correctly
128+
mock_get.assert_called_once_with(
129+
"/models/owner/name/with/slashes",
130+
options=Mock()
131+
)
132+
133+
134+
class TestAsyncModelGetBackwardCompatibility:
135+
"""Test backward compatibility for async models.get() method."""
136+
137+
@pytest.fixture
138+
async def async_client(self):
139+
"""Create an async Replicate client with mocked token."""
140+
return AsyncReplicate(bearer_token="test-token")
141+
142+
@pytest.mark.asyncio
143+
async def test_async_legacy_format_owner_name(self, async_client, mock_model_response):
144+
"""Test async legacy format: models.get('owner/name')."""
145+
# Mock the underlying _get method
146+
with patch.object(async_client.models, '_get', return_value=mock_model_response) as mock_get:
147+
# Call with legacy format
148+
result = await async_client.models.get("stability-ai/stable-diffusion")
149+
150+
# Verify underlying method was called with correct parameters
151+
mock_get.assert_called_once_with(
152+
"/models/stability-ai/stable-diffusion",
153+
options=Mock()
154+
)
155+
assert result == mock_model_response
156+
157+
@pytest.mark.asyncio
158+
async def test_async_new_format_keyword_args(self, async_client, mock_model_response):
159+
"""Test async new format: models.get(model_owner='owner', model_name='name')."""
160+
# Mock the underlying _get method
161+
with patch.object(async_client.models, '_get', return_value=mock_model_response) as mock_get:
162+
# Call with new format
163+
result = await async_client.models.get(model_owner="stability-ai", model_name="stable-diffusion")
164+
165+
# Verify underlying method was called with correct parameters
166+
mock_get.assert_called_once_with(
167+
"/models/stability-ai/stable-diffusion",
168+
options=Mock()
169+
)
170+
assert result == mock_model_response
171+
172+
@pytest.mark.asyncio
173+
async def test_async_error_mixed_formats(self, async_client):
174+
"""Test async error when mixing legacy and new formats."""
175+
with pytest.raises(ValueError) as exc_info:
176+
await async_client.models.get("stability-ai/stable-diffusion", model_owner="other-owner")
177+
178+
assert "Cannot specify both positional and keyword arguments" in str(exc_info.value)

0 commit comments

Comments
 (0)