1+ """
2+ Tests for backward compatibility in models.get() method.
3+ """
4+
5+ from unittest .mock import Mock , patch
6+
7+ import pytest
8+ import httpx
9+
10+ from replicate import Replicate , AsyncReplicate
11+ from replicate ._types import NOT_GIVEN
12+ from replicate .types .model_get_response import ModelGetResponse
13+
14+
15+ @pytest .fixture
16+ def mock_model_response ():
17+ """Mock response for model.get requests."""
18+ return ModelGetResponse (
19+ url = "https://replicate.com/stability-ai/stable-diffusion" ,
20+ owner = "stability-ai" ,
21+ name = "stable-diffusion" ,
22+ description = "A model for generating images from text prompts" ,
23+ visibility = "public" ,
24+ github_url = None ,
25+ paper_url = None ,
26+ license_url = None ,
27+ run_count = 0 ,
28+ cover_image_url = None ,
29+ default_example = None ,
30+ latest_version = None ,
31+ )
32+
33+
34+ class TestModelGetBackwardCompatibility :
35+ """Test backward compatibility for models.get() method."""
36+
37+ @pytest .fixture
38+ def client (self ):
39+ """Create a Replicate client with mocked token."""
40+ return Replicate (bearer_token = "test-token" )
41+
42+ def test_legacy_format_owner_name (self , client , mock_model_response ):
43+ """Test legacy format: models.get('owner/name')."""
44+ # Mock the underlying _get method
45+ with patch .object (client .models , '_get' , return_value = mock_model_response ) as mock_get :
46+ # Call with legacy format
47+ result = client .models .get ("stability-ai/stable-diffusion" )
48+
49+ # Verify underlying method was called with correct parameters
50+ mock_get .assert_called_once_with (
51+ "/models/stability-ai/stable-diffusion" ,
52+ options = Mock ()
53+ )
54+ assert result == mock_model_response
55+
56+ def test_new_format_keyword_args (self , client , mock_model_response ):
57+ """Test new format: models.get(model_owner='owner', model_name='name')."""
58+ # Mock the underlying _get method
59+ with patch .object (client .models , '_get' , return_value = mock_model_response ) as mock_get :
60+ # Call with new format
61+ result = client .models .get (model_owner = "stability-ai" , model_name = "stable-diffusion" )
62+
63+ # Verify underlying method was called with correct parameters
64+ mock_get .assert_called_once_with (
65+ "/models/stability-ai/stable-diffusion" ,
66+ options = Mock ()
67+ )
68+ assert result == mock_model_response
69+
70+ def test_legacy_format_with_extra_params (self , client , mock_model_response ):
71+ """Test legacy format with extra parameters."""
72+ # Mock the underlying _get method
73+ with patch .object (client .models , '_get' , return_value = mock_model_response ) as mock_get :
74+ # Call with legacy format and extra parameters
75+ result = client .models .get (
76+ "stability-ai/stable-diffusion" ,
77+ extra_headers = {"X-Custom" : "test" },
78+ timeout = 30.0
79+ )
80+
81+ # Verify underlying method was called
82+ mock_get .assert_called_once ()
83+ assert result == mock_model_response
84+
85+ def test_error_mixed_formats (self , client ):
86+ """Test error when mixing legacy and new formats."""
87+ with pytest .raises (ValueError ) as exc_info :
88+ client .models .get ("stability-ai/stable-diffusion" , model_owner = "other-owner" )
89+
90+ assert "Cannot specify both positional and keyword arguments" in str (exc_info .value )
91+
92+ def test_error_invalid_legacy_format (self , client ):
93+ """Test error for invalid legacy format (no slash)."""
94+ with pytest .raises (ValueError ) as exc_info :
95+ client .models .get ("invalid-format" )
96+
97+ assert "Invalid model reference 'invalid-format'" in str (exc_info .value )
98+ assert "Expected format: 'owner/name'" in str (exc_info .value )
99+
100+ def test_error_missing_parameters (self , client ):
101+ """Test error when no parameters are provided."""
102+ with pytest .raises (ValueError ) as exc_info :
103+ client .models .get ()
104+
105+ assert "model_owner and model_name are required" in str (exc_info .value )
106+
107+ def test_legacy_format_with_complex_names (self , client , mock_model_response ):
108+ """Test legacy format with complex owner/model names."""
109+ # Mock the underlying _get method
110+ with patch .object (client .models , '_get' , return_value = mock_model_response ) as mock_get :
111+ # Test with hyphenated names and numbers
112+ result = client .models .get ("black-forest-labs/flux-1.1-pro" )
113+
114+ # Verify parsing
115+ mock_get .assert_called_once_with (
116+ "/models/black-forest-labs/flux-1.1-pro" ,
117+ options = Mock ()
118+ )
119+
120+ def test_legacy_format_multiple_slashes (self , client ):
121+ """Test legacy format with multiple slashes (should split on first slash only)."""
122+ # Mock the underlying _get method
123+ with patch .object (client .models , '_get' , return_value = Mock ()) as mock_get :
124+ # This should work - split on first slash only
125+ client .models .get ("owner/name/with/slashes" )
126+
127+ # Verify it was parsed correctly
128+ mock_get .assert_called_once_with (
129+ "/models/owner/name/with/slashes" ,
130+ options = Mock ()
131+ )
132+
133+
134+ class TestAsyncModelGetBackwardCompatibility :
135+ """Test backward compatibility for async models.get() method."""
136+
137+ @pytest .fixture
138+ async def async_client (self ):
139+ """Create an async Replicate client with mocked token."""
140+ return AsyncReplicate (bearer_token = "test-token" )
141+
142+ @pytest .mark .asyncio
143+ async def test_async_legacy_format_owner_name (self , async_client , mock_model_response ):
144+ """Test async legacy format: models.get('owner/name')."""
145+ # Mock the underlying _get method
146+ with patch .object (async_client .models , '_get' , return_value = mock_model_response ) as mock_get :
147+ # Call with legacy format
148+ result = await async_client .models .get ("stability-ai/stable-diffusion" )
149+
150+ # Verify underlying method was called with correct parameters
151+ mock_get .assert_called_once_with (
152+ "/models/stability-ai/stable-diffusion" ,
153+ options = Mock ()
154+ )
155+ assert result == mock_model_response
156+
157+ @pytest .mark .asyncio
158+ async def test_async_new_format_keyword_args (self , async_client , mock_model_response ):
159+ """Test async new format: models.get(model_owner='owner', model_name='name')."""
160+ # Mock the underlying _get method
161+ with patch .object (async_client .models , '_get' , return_value = mock_model_response ) as mock_get :
162+ # Call with new format
163+ result = await async_client .models .get (model_owner = "stability-ai" , model_name = "stable-diffusion" )
164+
165+ # Verify underlying method was called with correct parameters
166+ mock_get .assert_called_once_with (
167+ "/models/stability-ai/stable-diffusion" ,
168+ options = Mock ()
169+ )
170+ assert result == mock_model_response
171+
172+ @pytest .mark .asyncio
173+ async def test_async_error_mixed_formats (self , async_client ):
174+ """Test async error when mixing legacy and new formats."""
175+ with pytest .raises (ValueError ) as exc_info :
176+ await async_client .models .get ("stability-ai/stable-diffusion" , model_owner = "other-owner" )
177+
178+ assert "Cannot specify both positional and keyword arguments" in str (exc_info .value )
0 commit comments