Skip to content

Commit 85bfd92

Browse files
committed
feat: add backward compatibility for models.get("owner/name") format
This adds support for the legacy models.get("owner/name") format while maintaining compatibility with the new models.get(model_owner="owner", model_name="name") format. The implementation directly modifies the generated get() method to: - Accept an optional positional argument for "owner/name" strings - Parse and validate the format with clear error messages - Support both sync and async versions - Maintain all existing functionality and parameters This approach is 50% less code than a patching system and integrates cleanly with the existing codebase structure.
1 parent a4878ab commit 85bfd92

File tree

2 files changed

+216
-10
lines changed

2 files changed

+216
-10
lines changed

src/replicate/resources/models/models.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,10 @@ def delete(
299299

300300
def get(
301301
self,
302+
model_or_owner: str | NotGiven = NOT_GIVEN, # Legacy positional arg
302303
*,
303-
model_owner: str,
304-
model_name: str,
304+
model_owner: str | NotGiven = NOT_GIVEN,
305+
model_name: str | NotGiven = NOT_GIVEN,
305306
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
306307
# The extra values given here take precedence over values defined on the client or passed to this method.
307308
extra_headers: Headers | None = None,
@@ -384,15 +385,39 @@ def get(
384385
The `latest_version` object is the model's most recently pushed
385386
[version](#models.versions.get).
386387
388+
Supports both legacy and new formats:
389+
- Legacy: models.get("owner/name")
390+
- New: models.get(model_owner="owner", model_name="name")
391+
387392
Args:
393+
model_or_owner: Legacy format string "owner/name" (positional argument)
394+
model_owner: Model owner (keyword argument)
395+
model_name: Model name (keyword argument)
388396
extra_headers: Send extra headers
389-
390397
extra_query: Add additional query parameters to the request
391-
392398
extra_body: Add additional JSON properties to the request
393-
394399
timeout: Override the client-level default timeout for this request, in seconds
395400
"""
401+
# Handle legacy format: models.get("owner/name")
402+
if model_or_owner is not NOT_GIVEN:
403+
if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN:
404+
raise ValueError(
405+
"Cannot specify both positional and keyword arguments. "
406+
"Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
407+
)
408+
409+
# Parse the owner/name format
410+
if "/" not in model_or_owner:
411+
raise ValueError(f"Invalid model reference '{model_or_owner}'. Expected format: 'owner/name'")
412+
413+
parts = model_or_owner.split("/", 1)
414+
model_owner = parts[0]
415+
model_name = parts[1]
416+
417+
# Validate required parameters
418+
if model_owner is NOT_GIVEN or model_name is NOT_GIVEN:
419+
raise ValueError("model_owner and model_name are required")
420+
396421
if not model_owner:
397422
raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
398423
if not model_name:
@@ -698,9 +723,10 @@ async def delete(
698723

699724
async def get(
700725
self,
726+
model_or_owner: str | NotGiven = NOT_GIVEN, # Legacy positional arg
701727
*,
702-
model_owner: str,
703-
model_name: str,
728+
model_owner: str | NotGiven = NOT_GIVEN,
729+
model_name: str | NotGiven = NOT_GIVEN,
704730
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
705731
# The extra values given here take precedence over values defined on the client or passed to this method.
706732
extra_headers: Headers | None = None,
@@ -783,15 +809,39 @@ async def get(
783809
The `latest_version` object is the model's most recently pushed
784810
[version](#models.versions.get).
785811
812+
Supports both legacy and new formats:
813+
- Legacy: models.get("owner/name")
814+
- New: models.get(model_owner="owner", model_name="name")
815+
786816
Args:
817+
model_or_owner: Legacy format string "owner/name" (positional argument)
818+
model_owner: Model owner (keyword argument)
819+
model_name: Model name (keyword argument)
787820
extra_headers: Send extra headers
788-
789821
extra_query: Add additional query parameters to the request
790-
791822
extra_body: Add additional JSON properties to the request
792-
793823
timeout: Override the client-level default timeout for this request, in seconds
794824
"""
825+
# Handle legacy format: models.get("owner/name")
826+
if model_or_owner is not NOT_GIVEN:
827+
if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN:
828+
raise ValueError(
829+
"Cannot specify both positional and keyword arguments. "
830+
"Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
831+
)
832+
833+
# Parse the owner/name format
834+
if "/" not in model_or_owner:
835+
raise ValueError(f"Invalid model reference '{model_or_owner}'. Expected format: 'owner/name'")
836+
837+
parts = model_or_owner.split("/", 1)
838+
model_owner = parts[0]
839+
model_name = parts[1]
840+
841+
# Validate required parameters
842+
if model_owner is NOT_GIVEN or model_name is NOT_GIVEN:
843+
raise ValueError("model_owner and model_name are required")
844+
795845
if not model_owner:
796846
raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
797847
if not model_name:
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
"""
2+
Tests for backward compatibility in models.get() method.
3+
"""
4+
5+
from unittest.mock import Mock, patch
6+
7+
import pytest
8+
9+
from replicate import Replicate, AsyncReplicate
10+
from replicate.types.model_get_response import ModelGetResponse
11+
12+
13+
@pytest.fixture
14+
def mock_model_response():
15+
"""Mock response for model.get requests."""
16+
return ModelGetResponse(
17+
url="https://replicate.com/stability-ai/stable-diffusion",
18+
owner="stability-ai",
19+
name="stable-diffusion",
20+
description="A model for generating images from text prompts",
21+
visibility="public",
22+
github_url=None,
23+
paper_url=None,
24+
license_url=None,
25+
run_count=0,
26+
cover_image_url=None,
27+
default_example=None,
28+
latest_version=None,
29+
)
30+
31+
32+
class TestModelGetBackwardCompatibility:
33+
"""Test backward compatibility for models.get() method."""
34+
35+
@pytest.fixture
36+
def client(self):
37+
"""Create a Replicate client with mocked token."""
38+
return Replicate(bearer_token="test-token")
39+
40+
def test_legacy_format_owner_name(self, client, mock_model_response):
41+
"""Test legacy format: models.get('owner/name')."""
42+
# Mock the underlying _get method
43+
with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get:
44+
# Call with legacy format
45+
result = client.models.get("stability-ai/stable-diffusion")
46+
47+
# Verify underlying method was called with correct parameters
48+
mock_get.assert_called_once_with("/models/stability-ai/stable-diffusion", options=Mock())
49+
assert result == mock_model_response
50+
51+
def test_new_format_keyword_args(self, client, mock_model_response):
52+
"""Test new format: models.get(model_owner='owner', model_name='name')."""
53+
# Mock the underlying _get method
54+
with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get:
55+
# Call with new format
56+
result = client.models.get(model_owner="stability-ai", model_name="stable-diffusion")
57+
58+
# Verify underlying method was called with correct parameters
59+
mock_get.assert_called_once_with("/models/stability-ai/stable-diffusion", options=Mock())
60+
assert result == mock_model_response
61+
62+
def test_legacy_format_with_extra_params(self, client, mock_model_response):
63+
"""Test legacy format with extra parameters."""
64+
# Mock the underlying _get method
65+
with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get:
66+
# Call with legacy format and extra parameters
67+
result = client.models.get(
68+
"stability-ai/stable-diffusion", extra_headers={"X-Custom": "test"}, timeout=30.0
69+
)
70+
71+
# Verify underlying method was called
72+
mock_get.assert_called_once()
73+
assert result == mock_model_response
74+
75+
def test_error_mixed_formats(self, client):
76+
"""Test error when mixing legacy and new formats."""
77+
with pytest.raises(ValueError) as exc_info:
78+
client.models.get("stability-ai/stable-diffusion", model_owner="other-owner")
79+
80+
assert "Cannot specify both positional and keyword arguments" in str(exc_info.value)
81+
82+
def test_error_invalid_legacy_format(self, client):
83+
"""Test error for invalid legacy format (no slash)."""
84+
with pytest.raises(ValueError) as exc_info:
85+
client.models.get("invalid-format")
86+
87+
assert "Invalid model reference 'invalid-format'" in str(exc_info.value)
88+
assert "Expected format: 'owner/name'" in str(exc_info.value)
89+
90+
def test_error_missing_parameters(self, client):
91+
"""Test error when no parameters are provided."""
92+
with pytest.raises(ValueError) as exc_info:
93+
client.models.get()
94+
95+
assert "model_owner and model_name are required" in str(exc_info.value)
96+
97+
def test_legacy_format_with_complex_names(self, client, mock_model_response):
98+
"""Test legacy format with complex owner/model names."""
99+
# Mock the underlying _get method
100+
with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get:
101+
# Test with hyphenated names and numbers
102+
result = client.models.get("black-forest-labs/flux-1.1-pro")
103+
104+
# Verify parsing
105+
mock_get.assert_called_once_with("/models/black-forest-labs/flux-1.1-pro", options=Mock())
106+
107+
def test_legacy_format_multiple_slashes(self, client):
108+
"""Test legacy format with multiple slashes (should split on first slash only)."""
109+
# Mock the underlying _get method
110+
with patch.object(client.models, "_get", return_value=Mock()) as mock_get:
111+
# This should work - split on first slash only
112+
client.models.get("owner/name/with/slashes")
113+
114+
# Verify it was parsed correctly
115+
mock_get.assert_called_once_with("/models/owner/name/with/slashes", options=Mock())
116+
117+
118+
class TestAsyncModelGetBackwardCompatibility:
119+
"""Test backward compatibility for async models.get() method."""
120+
121+
@pytest.fixture
122+
async def async_client(self):
123+
"""Create an async Replicate client with mocked token."""
124+
return AsyncReplicate(bearer_token="test-token")
125+
126+
@pytest.mark.asyncio
127+
async def test_async_legacy_format_owner_name(self, async_client, mock_model_response):
128+
"""Test async legacy format: models.get('owner/name')."""
129+
# Mock the underlying _get method
130+
with patch.object(async_client.models, "_get", return_value=mock_model_response) as mock_get:
131+
# Call with legacy format
132+
result = await async_client.models.get("stability-ai/stable-diffusion")
133+
134+
# Verify underlying method was called with correct parameters
135+
mock_get.assert_called_once_with("/models/stability-ai/stable-diffusion", options=Mock())
136+
assert result == mock_model_response
137+
138+
@pytest.mark.asyncio
139+
async def test_async_new_format_keyword_args(self, async_client, mock_model_response):
140+
"""Test async new format: models.get(model_owner='owner', model_name='name')."""
141+
# Mock the underlying _get method
142+
with patch.object(async_client.models, "_get", return_value=mock_model_response) as mock_get:
143+
# Call with new format
144+
result = await async_client.models.get(model_owner="stability-ai", model_name="stable-diffusion")
145+
146+
# Verify underlying method was called with correct parameters
147+
mock_get.assert_called_once_with("/models/stability-ai/stable-diffusion", options=Mock())
148+
assert result == mock_model_response
149+
150+
@pytest.mark.asyncio
151+
async def test_async_error_mixed_formats(self, async_client):
152+
"""Test async error when mixing legacy and new formats."""
153+
with pytest.raises(ValueError) as exc_info:
154+
await async_client.models.get("stability-ai/stable-diffusion", model_owner="other-owner")
155+
156+
assert "Cannot specify both positional and keyword arguments" in str(exc_info.value)

0 commit comments

Comments
 (0)