1111
1212
1313@pytest .fixture
14- def mock_model_response ():
14+ def mock_model_response () -> ModelGetResponse :
1515 """Mock response for model.get requests."""
1616 return ModelGetResponse (
1717 url = "https://replicate.com/stability-ai/stable-diffusion" ,
@@ -33,11 +33,11 @@ class TestModelGetBackwardCompatibility:
3333 """Test backward compatibility for models.get() method."""
3434
3535 @pytest .fixture
36- def client (self ):
36+ def client (self ) -> Replicate :
3737 """Create a Replicate client with mocked token."""
3838 return Replicate (bearer_token = "test-token" )
3939
40- def test_legacy_format_owner_name (self , client , mock_model_response ) :
40+ def test_legacy_format_owner_name (self , client : Replicate , mock_model_response : ModelGetResponse ) -> None :
4141 """Test legacy format: models.get('owner/name')."""
4242 # Mock the underlying _get method
4343 with patch .object (client .models , "_get" , return_value = mock_model_response ) as mock_get :
@@ -52,7 +52,7 @@ def test_legacy_format_owner_name(self, client, mock_model_response):
5252 )
5353 assert result == mock_model_response
5454
55- def test_new_format_keyword_args (self , client , mock_model_response ) :
55+ def test_new_format_keyword_args (self , client : Replicate , mock_model_response : ModelGetResponse ) -> None :
5656 """Test new format: models.get(model_owner='owner', model_name='name')."""
5757 # Mock the underlying _get method
5858 with patch .object (client .models , "_get" , return_value = mock_model_response ) as mock_get :
@@ -67,7 +67,7 @@ def test_new_format_keyword_args(self, client, mock_model_response):
6767 )
6868 assert result == mock_model_response
6969
70- def test_legacy_format_with_extra_params (self , client , mock_model_response ) :
70+ def test_legacy_format_with_extra_params (self , client : Replicate , mock_model_response : ModelGetResponse ) -> None :
7171 """Test legacy format with extra parameters."""
7272 # Mock the underlying _get method
7373 with patch .object (client .models , "_get" , return_value = mock_model_response ) as mock_get :
@@ -80,29 +80,29 @@ def test_legacy_format_with_extra_params(self, client, mock_model_response):
8080 mock_get .assert_called_once ()
8181 assert result == mock_model_response
8282
83- def test_error_mixed_formats (self , client ) :
83+ def test_error_mixed_formats (self , client : Replicate ) -> None :
8484 """Test error when mixing legacy and new formats."""
8585 with pytest .raises (ValueError ) as exc_info :
8686 client .models .get ("stability-ai/stable-diffusion" , model_owner = "other-owner" )
8787
8888 assert "Cannot specify both positional and keyword arguments" in str (exc_info .value )
8989
90- def test_error_invalid_legacy_format (self , client ) :
90+ def test_error_invalid_legacy_format (self , client : Replicate ) -> None :
9191 """Test error for invalid legacy format (no slash)."""
9292 with pytest .raises (ValueError ) as exc_info :
9393 client .models .get ("invalid-format" )
9494
9595 assert "Invalid model reference 'invalid-format'" in str (exc_info .value )
9696 assert "Expected format: 'owner/name'" in str (exc_info .value )
9797
98- def test_error_missing_parameters (self , client ) :
98+ def test_error_missing_parameters (self , client : Replicate ) -> None :
9999 """Test error when no parameters are provided."""
100100 with pytest .raises (ValueError ) as exc_info :
101101 client .models .get ()
102102
103103 assert "model_owner and model_name are required" in str (exc_info .value )
104104
105- def test_legacy_format_with_complex_names (self , client , mock_model_response ) :
105+ def test_legacy_format_with_complex_names (self , client : Replicate , mock_model_response : ModelGetResponse ) -> None :
106106 """Test legacy format with complex owner/model names."""
107107 # Mock the underlying _get method
108108 with patch .object (client .models , "_get" , return_value = mock_model_response ) as mock_get :
@@ -115,8 +115,9 @@ def test_legacy_format_with_complex_names(self, client, mock_model_response):
115115 options = {},
116116 cast_to = ModelGetResponse
117117 )
118+ assert result == mock_model_response
118119
119- def test_legacy_format_multiple_slashes (self , client ) :
120+ def test_legacy_format_multiple_slashes (self , client : Replicate ) -> None :
120121 """Test legacy format with multiple slashes (should split on first slash only)."""
121122 # Mock the underlying _get method
122123 with patch .object (client .models , "_get" , return_value = Mock ()) as mock_get :
@@ -135,12 +136,12 @@ class TestAsyncModelGetBackwardCompatibility:
135136 """Test backward compatibility for async models.get() method."""
136137
137138 @pytest .fixture
138- async def async_client (self ):
139+ async def async_client (self ) -> AsyncReplicate :
139140 """Create an async Replicate client with mocked token."""
140141 return AsyncReplicate (bearer_token = "test-token" )
141142
142143 @pytest .mark .asyncio
143- async def test_async_legacy_format_owner_name (self , async_client , mock_model_response ) :
144+ async def test_async_legacy_format_owner_name (self , async_client : AsyncReplicate , mock_model_response : ModelGetResponse ) -> None :
144145 """Test async legacy format: models.get('owner/name')."""
145146 # Mock the underlying _get method
146147 with patch .object (async_client .models , "_get" , return_value = mock_model_response ) as mock_get :
@@ -156,7 +157,7 @@ async def test_async_legacy_format_owner_name(self, async_client, mock_model_res
156157 assert result == mock_model_response
157158
158159 @pytest .mark .asyncio
159- async def test_async_new_format_keyword_args (self , async_client , mock_model_response ) :
160+ async def test_async_new_format_keyword_args (self , async_client : AsyncReplicate , mock_model_response : ModelGetResponse ) -> None :
160161 """Test async new format: models.get(model_owner='owner', model_name='name')."""
161162 # Mock the underlying _get method
162163 with patch .object (async_client .models , "_get" , return_value = mock_model_response ) as mock_get :
@@ -172,7 +173,7 @@ async def test_async_new_format_keyword_args(self, async_client, mock_model_resp
172173 assert result == mock_model_response
173174
174175 @pytest .mark .asyncio
175- async def test_async_error_mixed_formats (self , async_client ) :
176+ async def test_async_error_mixed_formats (self , async_client : AsyncReplicate ) -> None :
176177 """Test async error when mixing legacy and new formats."""
177178 with pytest .raises (ValueError ) as exc_info :
178179 await async_client .models .get ("stability-ai/stable-diffusion" , model_owner = "other-owner" )
0 commit comments