55from unittest .mock import Mock , patch
66
77import pytest
8- import httpx
98
109from replicate import Replicate , AsyncReplicate
11- from replicate ._types import NOT_GIVEN
1210from replicate .types .model_get_response import ModelGetResponse
1311
1412
@@ -42,42 +40,42 @@ def client(self):
4240 def test_legacy_format_owner_name (self , client , mock_model_response ):
4341 """Test legacy format: models.get('owner/name')."""
4442 # Mock the underlying _get method
45- with patch .object (client .models , ' _get' , return_value = mock_model_response ) as mock_get :
43+ with patch .object (client .models , " _get" , return_value = mock_model_response ) as mock_get :
4644 # Call with legacy format
4745 result = client .models .get ("stability-ai/stable-diffusion" )
48-
46+
4947 # Verify underlying method was called with correct parameters
5048 mock_get .assert_called_once_with (
5149 "/models/stability-ai/stable-diffusion" ,
52- options = Mock ()
50+ options = {},
51+ cast_to = ModelGetResponse
5352 )
5453 assert result == mock_model_response
5554
5655 def test_new_format_keyword_args (self , client , mock_model_response ):
5756 """Test new format: models.get(model_owner='owner', model_name='name')."""
5857 # Mock the underlying _get method
59- with patch .object (client .models , ' _get' , return_value = mock_model_response ) as mock_get :
58+ with patch .object (client .models , " _get" , return_value = mock_model_response ) as mock_get :
6059 # Call with new format
6160 result = client .models .get (model_owner = "stability-ai" , model_name = "stable-diffusion" )
62-
61+
6362 # Verify underlying method was called with correct parameters
6463 mock_get .assert_called_once_with (
6564 "/models/stability-ai/stable-diffusion" ,
66- options = Mock ()
65+ options = {},
66+ cast_to = ModelGetResponse
6767 )
6868 assert result == mock_model_response
6969
7070 def test_legacy_format_with_extra_params (self , client , mock_model_response ):
7171 """Test legacy format with extra parameters."""
7272 # Mock the underlying _get method
73- with patch .object (client .models , ' _get' , return_value = mock_model_response ) as mock_get :
73+ with patch .object (client .models , " _get" , return_value = mock_model_response ) as mock_get :
7474 # Call with legacy format and extra parameters
7575 result = client .models .get (
76- "stability-ai/stable-diffusion" ,
77- extra_headers = {"X-Custom" : "test" },
78- timeout = 30.0
76+ "stability-ai/stable-diffusion" , extra_headers = {"X-Custom" : "test" }, timeout = 30.0
7977 )
80-
78+
8179 # Verify underlying method was called
8280 mock_get .assert_called_once ()
8381 assert result == mock_model_response
@@ -86,48 +84,50 @@ def test_error_mixed_formats(self, client):
8684 """Test error when mixing legacy and new formats."""
8785 with pytest .raises (ValueError ) as exc_info :
8886 client .models .get ("stability-ai/stable-diffusion" , model_owner = "other-owner" )
89-
87+
9088 assert "Cannot specify both positional and keyword arguments" in str (exc_info .value )
9189
9290 def test_error_invalid_legacy_format (self , client ):
9391 """Test error for invalid legacy format (no slash)."""
9492 with pytest .raises (ValueError ) as exc_info :
9593 client .models .get ("invalid-format" )
96-
94+
9795 assert "Invalid model reference 'invalid-format'" in str (exc_info .value )
9896 assert "Expected format: 'owner/name'" in str (exc_info .value )
9997
10098 def test_error_missing_parameters (self , client ):
10199 """Test error when no parameters are provided."""
102100 with pytest .raises (ValueError ) as exc_info :
103101 client .models .get ()
104-
102+
105103 assert "model_owner and model_name are required" in str (exc_info .value )
106104
107105 def test_legacy_format_with_complex_names (self , client , mock_model_response ):
108106 """Test legacy format with complex owner/model names."""
109107 # Mock the underlying _get method
110- with patch .object (client .models , ' _get' , return_value = mock_model_response ) as mock_get :
108+ with patch .object (client .models , " _get" , return_value = mock_model_response ) as mock_get :
111109 # Test with hyphenated names and numbers
112110 result = client .models .get ("black-forest-labs/flux-1.1-pro" )
113-
111+
114112 # Verify parsing
115113 mock_get .assert_called_once_with (
116114 "/models/black-forest-labs/flux-1.1-pro" ,
117- options = Mock ()
115+ options = {},
116+ cast_to = ModelGetResponse
118117 )
119118
120119 def test_legacy_format_multiple_slashes (self , client ):
121120 """Test legacy format with multiple slashes (should split on first slash only)."""
122121 # Mock the underlying _get method
123- with patch .object (client .models , ' _get' , return_value = Mock ()) as mock_get :
122+ with patch .object (client .models , " _get" , return_value = Mock ()) as mock_get :
124123 # This should work - split on first slash only
125124 client .models .get ("owner/name/with/slashes" )
126-
125+
127126 # Verify it was parsed correctly
128127 mock_get .assert_called_once_with (
129128 "/models/owner/name/with/slashes" ,
130- options = Mock ()
129+ options = {},
130+ cast_to = ModelGetResponse
131131 )
132132
133133
@@ -143,29 +143,31 @@ async def async_client(self):
143143 async def test_async_legacy_format_owner_name (self , async_client , mock_model_response ):
144144 """Test async legacy format: models.get('owner/name')."""
145145 # Mock the underlying _get method
146- with patch .object (async_client .models , ' _get' , return_value = mock_model_response ) as mock_get :
146+ with patch .object (async_client .models , " _get" , return_value = mock_model_response ) as mock_get :
147147 # Call with legacy format
148148 result = await async_client .models .get ("stability-ai/stable-diffusion" )
149-
149+
150150 # Verify underlying method was called with correct parameters
151151 mock_get .assert_called_once_with (
152152 "/models/stability-ai/stable-diffusion" ,
153- options = Mock ()
153+ options = {},
154+ cast_to = ModelGetResponse
154155 )
155156 assert result == mock_model_response
156157
157158 @pytest .mark .asyncio
158159 async def test_async_new_format_keyword_args (self , async_client , mock_model_response ):
159160 """Test async new format: models.get(model_owner='owner', model_name='name')."""
160161 # Mock the underlying _get method
161- with patch .object (async_client .models , ' _get' , return_value = mock_model_response ) as mock_get :
162+ with patch .object (async_client .models , " _get" , return_value = mock_model_response ) as mock_get :
162163 # Call with new format
163164 result = await async_client .models .get (model_owner = "stability-ai" , model_name = "stable-diffusion" )
164-
165+
165166 # Verify underlying method was called with correct parameters
166167 mock_get .assert_called_once_with (
167168 "/models/stability-ai/stable-diffusion" ,
168- options = Mock ()
169+ options = {},
170+ cast_to = ModelGetResponse
169171 )
170172 assert result == mock_model_response
171173
@@ -174,5 +176,5 @@ async def test_async_error_mixed_formats(self, async_client):
174176 """Test async error when mixing legacy and new formats."""
175177 with pytest .raises (ValueError ) as exc_info :
176178 await async_client .models .get ("stability-ai/stable-diffusion" , model_owner = "other-owner" )
177-
178- assert "Cannot specify both positional and keyword arguments" in str (exc_info .value )
179+
180+ assert "Cannot specify both positional and keyword arguments" in str (exc_info .value )
0 commit comments