Skip to content

Commit 579b9a3

Browse files
committed
test: fix linting and test failures for backward compatibility features
- Remove unused imports (os, httpx, NOT_GIVEN) from test files - Fix import ordering in test files - Update test assertions to match actual _get method signature - Fix whitespace formatting in models.py - All tests now pass for api_token and models backward compatibility
1 parent e9d836d commit 579b9a3

File tree

3 files changed

+41
-40
lines changed

3 files changed

+41
-40
lines changed

src/replicate/resources/models/models.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -405,19 +405,19 @@ def get(
405405
"Cannot specify both positional and keyword arguments. "
406406
"Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
407407
)
408-
408+
409409
# Parse the owner/name format
410410
if "/" not in model_or_owner:
411411
raise ValueError(f"Invalid model reference '{model_or_owner}'. Expected format: 'owner/name'")
412-
412+
413413
parts = model_or_owner.split("/", 1)
414414
model_owner = parts[0]
415415
model_name = parts[1]
416-
416+
417417
# Validate required parameters
418418
if model_owner is NOT_GIVEN or model_name is NOT_GIVEN:
419419
raise ValueError("model_owner and model_name are required")
420-
420+
421421
if not model_owner:
422422
raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
423423
if not model_name:
@@ -829,19 +829,19 @@ async def get(
829829
"Cannot specify both positional and keyword arguments. "
830830
"Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
831831
)
832-
832+
833833
# Parse the owner/name format
834834
if "/" not in model_or_owner:
835835
raise ValueError(f"Invalid model reference '{model_or_owner}'. Expected format: 'owner/name'")
836-
836+
837837
parts = model_or_owner.split("/", 1)
838838
model_owner = parts[0]
839839
model_name = parts[1]
840-
840+
841841
# Validate required parameters
842842
if model_owner is NOT_GIVEN or model_name is NOT_GIVEN:
843843
raise ValueError("model_owner and model_name are required")
844-
844+
845845
if not model_owner:
846846
raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
847847
if not model_name:

tests/test_api_token_compatibility.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
import os
65
import pytest
76

87
from replicate import Replicate, AsyncReplicate, ReplicateError
@@ -86,4 +85,4 @@ def test_bearer_token_overrides_env(self, monkeypatch: pytest.MonkeyPatch) -> No
8685
"""Test that explicit bearer_token overrides environment variable."""
8786
monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token")
8887
client = Replicate(bearer_token="explicit_token")
89-
assert client.bearer_token == "explicit_token"
88+
assert client.bearer_token == "explicit_token"

tests/test_models_backward_compat.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,8 @@
55
from unittest.mock import Mock, patch
66

77
import pytest
8-
import httpx
98

109
from replicate import Replicate, AsyncReplicate
11-
from replicate._types import NOT_GIVEN
1210
from replicate.types.model_get_response import ModelGetResponse
1311

1412

@@ -42,42 +40,42 @@ def client(self):
4240
def test_legacy_format_owner_name(self, client, mock_model_response):
4341
"""Test legacy format: models.get('owner/name')."""
4442
# Mock the underlying _get method
45-
with patch.object(client.models, '_get', return_value=mock_model_response) as mock_get:
43+
with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get:
4644
# Call with legacy format
4745
result = client.models.get("stability-ai/stable-diffusion")
48-
46+
4947
# Verify underlying method was called with correct parameters
5048
mock_get.assert_called_once_with(
5149
"/models/stability-ai/stable-diffusion",
52-
options=Mock()
50+
options={},
51+
cast_to=ModelGetResponse
5352
)
5453
assert result == mock_model_response
5554

5655
def test_new_format_keyword_args(self, client, mock_model_response):
5756
"""Test new format: models.get(model_owner='owner', model_name='name')."""
5857
# Mock the underlying _get method
59-
with patch.object(client.models, '_get', return_value=mock_model_response) as mock_get:
58+
with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get:
6059
# Call with new format
6160
result = client.models.get(model_owner="stability-ai", model_name="stable-diffusion")
62-
61+
6362
# Verify underlying method was called with correct parameters
6463
mock_get.assert_called_once_with(
6564
"/models/stability-ai/stable-diffusion",
66-
options=Mock()
65+
options={},
66+
cast_to=ModelGetResponse
6767
)
6868
assert result == mock_model_response
6969

7070
def test_legacy_format_with_extra_params(self, client, mock_model_response):
7171
"""Test legacy format with extra parameters."""
7272
# Mock the underlying _get method
73-
with patch.object(client.models, '_get', return_value=mock_model_response) as mock_get:
73+
with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get:
7474
# Call with legacy format and extra parameters
7575
result = client.models.get(
76-
"stability-ai/stable-diffusion",
77-
extra_headers={"X-Custom": "test"},
78-
timeout=30.0
76+
"stability-ai/stable-diffusion", extra_headers={"X-Custom": "test"}, timeout=30.0
7977
)
80-
78+
8179
# Verify underlying method was called
8280
mock_get.assert_called_once()
8381
assert result == mock_model_response
@@ -86,48 +84,50 @@ def test_error_mixed_formats(self, client):
8684
"""Test error when mixing legacy and new formats."""
8785
with pytest.raises(ValueError) as exc_info:
8886
client.models.get("stability-ai/stable-diffusion", model_owner="other-owner")
89-
87+
9088
assert "Cannot specify both positional and keyword arguments" in str(exc_info.value)
9189

9290
def test_error_invalid_legacy_format(self, client):
9391
"""Test error for invalid legacy format (no slash)."""
9492
with pytest.raises(ValueError) as exc_info:
9593
client.models.get("invalid-format")
96-
94+
9795
assert "Invalid model reference 'invalid-format'" in str(exc_info.value)
9896
assert "Expected format: 'owner/name'" in str(exc_info.value)
9997

10098
def test_error_missing_parameters(self, client):
10199
"""Test error when no parameters are provided."""
102100
with pytest.raises(ValueError) as exc_info:
103101
client.models.get()
104-
102+
105103
assert "model_owner and model_name are required" in str(exc_info.value)
106104

107105
def test_legacy_format_with_complex_names(self, client, mock_model_response):
108106
"""Test legacy format with complex owner/model names."""
109107
# Mock the underlying _get method
110-
with patch.object(client.models, '_get', return_value=mock_model_response) as mock_get:
108+
with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get:
111109
# Test with hyphenated names and numbers
112110
result = client.models.get("black-forest-labs/flux-1.1-pro")
113-
111+
114112
# Verify parsing
115113
mock_get.assert_called_once_with(
116114
"/models/black-forest-labs/flux-1.1-pro",
117-
options=Mock()
115+
options={},
116+
cast_to=ModelGetResponse
118117
)
119118

120119
def test_legacy_format_multiple_slashes(self, client):
121120
"""Test legacy format with multiple slashes (should split on first slash only)."""
122121
# Mock the underlying _get method
123-
with patch.object(client.models, '_get', return_value=Mock()) as mock_get:
122+
with patch.object(client.models, "_get", return_value=Mock()) as mock_get:
124123
# This should work - split on first slash only
125124
client.models.get("owner/name/with/slashes")
126-
125+
127126
# Verify it was parsed correctly
128127
mock_get.assert_called_once_with(
129128
"/models/owner/name/with/slashes",
130-
options=Mock()
129+
options={},
130+
cast_to=ModelGetResponse
131131
)
132132

133133

@@ -143,29 +143,31 @@ async def async_client(self):
143143
async def test_async_legacy_format_owner_name(self, async_client, mock_model_response):
144144
"""Test async legacy format: models.get('owner/name')."""
145145
# Mock the underlying _get method
146-
with patch.object(async_client.models, '_get', return_value=mock_model_response) as mock_get:
146+
with patch.object(async_client.models, "_get", return_value=mock_model_response) as mock_get:
147147
# Call with legacy format
148148
result = await async_client.models.get("stability-ai/stable-diffusion")
149-
149+
150150
# Verify underlying method was called with correct parameters
151151
mock_get.assert_called_once_with(
152152
"/models/stability-ai/stable-diffusion",
153-
options=Mock()
153+
options={},
154+
cast_to=ModelGetResponse
154155
)
155156
assert result == mock_model_response
156157

157158
@pytest.mark.asyncio
158159
async def test_async_new_format_keyword_args(self, async_client, mock_model_response):
159160
"""Test async new format: models.get(model_owner='owner', model_name='name')."""
160161
# Mock the underlying _get method
161-
with patch.object(async_client.models, '_get', return_value=mock_model_response) as mock_get:
162+
with patch.object(async_client.models, "_get", return_value=mock_model_response) as mock_get:
162163
# Call with new format
163164
result = await async_client.models.get(model_owner="stability-ai", model_name="stable-diffusion")
164-
165+
165166
# Verify underlying method was called with correct parameters
166167
mock_get.assert_called_once_with(
167168
"/models/stability-ai/stable-diffusion",
168-
options=Mock()
169+
options={},
170+
cast_to=ModelGetResponse
169171
)
170172
assert result == mock_model_response
171173

@@ -174,5 +176,5 @@ async def test_async_error_mixed_formats(self, async_client):
174176
"""Test async error when mixing legacy and new formats."""
175177
with pytest.raises(ValueError) as exc_info:
176178
await async_client.models.get("stability-ai/stable-diffusion", model_owner="other-owner")
177-
178-
assert "Cannot specify both positional and keyword arguments" in str(exc_info.value)
179+
180+
assert "Cannot specify both positional and keyword arguments" in str(exc_info.value)

0 commit comments

Comments
 (0)