Skip to content

Commit 02934fc

Browse files
committed
fix: add type guards and annotations for models.get() backward compatibility
- Add type guard (assert isinstance) for model_or_owner string check in both sync and async models.get() methods - Add proper type annotations to all test method parameters in test_models_backward_compat.py - Fix unused variable warning in test by adding assertion - Resolves pyright type checking errors while maintaining backward compatibility
1 parent 579b9a3 commit 02934fc

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

src/replicate/resources/models/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,9 @@ def get(
406406
"Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
407407
)
408408

409+
# Type guard: ensure model_or_owner is a string
410+
assert isinstance(model_or_owner, str), "model_or_owner must be a string"
411+
409412
# Parse the owner/name format
410413
if "/" not in model_or_owner:
411414
raise ValueError(f"Invalid model reference '{model_or_owner}'. Expected format: 'owner/name'")
@@ -830,6 +833,9 @@ async def get(
830833
"Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
831834
)
832835

836+
# Type guard: ensure model_or_owner is a string
837+
assert isinstance(model_or_owner, str), "model_or_owner must be a string"
838+
833839
# Parse the owner/name format
834840
if "/" not in model_or_owner:
835841
raise ValueError(f"Invalid model reference '{model_or_owner}'. Expected format: 'owner/name'")

tests/test_models_backward_compat.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
@pytest.fixture
14-
def mock_model_response():
14+
def mock_model_response() -> ModelGetResponse:
1515
"""Mock response for model.get requests."""
1616
return ModelGetResponse(
1717
url="https://replicate.com/stability-ai/stable-diffusion",
@@ -33,11 +33,11 @@ class TestModelGetBackwardCompatibility:
3333
"""Test backward compatibility for models.get() method."""
3434

3535
@pytest.fixture
36-
def client(self):
36+
def client(self) -> Replicate:
3737
"""Create a Replicate client with mocked token."""
3838
return Replicate(bearer_token="test-token")
3939

40-
def test_legacy_format_owner_name(self, client, mock_model_response):
40+
def test_legacy_format_owner_name(self, client: Replicate, mock_model_response: ModelGetResponse) -> None:
4141
"""Test legacy format: models.get('owner/name')."""
4242
# Mock the underlying _get method
4343
with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get:
@@ -52,7 +52,7 @@ def test_legacy_format_owner_name(self, client, mock_model_response):
5252
)
5353
assert result == mock_model_response
5454

55-
def test_new_format_keyword_args(self, client, mock_model_response):
55+
def test_new_format_keyword_args(self, client: Replicate, mock_model_response: ModelGetResponse) -> None:
5656
"""Test new format: models.get(model_owner='owner', model_name='name')."""
5757
# Mock the underlying _get method
5858
with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get:
@@ -67,7 +67,7 @@ def test_new_format_keyword_args(self, client, mock_model_response):
6767
)
6868
assert result == mock_model_response
6969

70-
def test_legacy_format_with_extra_params(self, client, mock_model_response):
70+
def test_legacy_format_with_extra_params(self, client: Replicate, mock_model_response: ModelGetResponse) -> None:
7171
"""Test legacy format with extra parameters."""
7272
# Mock the underlying _get method
7373
with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get:
@@ -80,29 +80,29 @@ def test_legacy_format_with_extra_params(self, client, mock_model_response):
8080
mock_get.assert_called_once()
8181
assert result == mock_model_response
8282

83-
def test_error_mixed_formats(self, client):
83+
def test_error_mixed_formats(self, client: Replicate) -> None:
8484
"""Test error when mixing legacy and new formats."""
8585
with pytest.raises(ValueError) as exc_info:
8686
client.models.get("stability-ai/stable-diffusion", model_owner="other-owner")
8787

8888
assert "Cannot specify both positional and keyword arguments" in str(exc_info.value)
8989

90-
def test_error_invalid_legacy_format(self, client):
90+
def test_error_invalid_legacy_format(self, client: Replicate) -> None:
9191
"""Test error for invalid legacy format (no slash)."""
9292
with pytest.raises(ValueError) as exc_info:
9393
client.models.get("invalid-format")
9494

9595
assert "Invalid model reference 'invalid-format'" in str(exc_info.value)
9696
assert "Expected format: 'owner/name'" in str(exc_info.value)
9797

98-
def test_error_missing_parameters(self, client):
98+
def test_error_missing_parameters(self, client: Replicate) -> None:
9999
"""Test error when no parameters are provided."""
100100
with pytest.raises(ValueError) as exc_info:
101101
client.models.get()
102102

103103
assert "model_owner and model_name are required" in str(exc_info.value)
104104

105-
def test_legacy_format_with_complex_names(self, client, mock_model_response):
105+
def test_legacy_format_with_complex_names(self, client: Replicate, mock_model_response: ModelGetResponse) -> None:
106106
"""Test legacy format with complex owner/model names."""
107107
# Mock the underlying _get method
108108
with patch.object(client.models, "_get", return_value=mock_model_response) as mock_get:
@@ -115,8 +115,9 @@ def test_legacy_format_with_complex_names(self, client, mock_model_response):
115115
options={},
116116
cast_to=ModelGetResponse
117117
)
118+
assert result == mock_model_response
118119

119-
def test_legacy_format_multiple_slashes(self, client):
120+
def test_legacy_format_multiple_slashes(self, client: Replicate) -> None:
120121
"""Test legacy format with multiple slashes (should split on first slash only)."""
121122
# Mock the underlying _get method
122123
with patch.object(client.models, "_get", return_value=Mock()) as mock_get:
@@ -135,12 +136,12 @@ class TestAsyncModelGetBackwardCompatibility:
135136
"""Test backward compatibility for async models.get() method."""
136137

137138
@pytest.fixture
138-
async def async_client(self):
139+
async def async_client(self) -> AsyncReplicate:
139140
"""Create an async Replicate client with mocked token."""
140141
return AsyncReplicate(bearer_token="test-token")
141142

142143
@pytest.mark.asyncio
143-
async def test_async_legacy_format_owner_name(self, async_client, mock_model_response):
144+
async def test_async_legacy_format_owner_name(self, async_client: AsyncReplicate, mock_model_response: ModelGetResponse) -> None:
144145
"""Test async legacy format: models.get('owner/name')."""
145146
# Mock the underlying _get method
146147
with patch.object(async_client.models, "_get", return_value=mock_model_response) as mock_get:
@@ -156,7 +157,7 @@ async def test_async_legacy_format_owner_name(self, async_client, mock_model_res
156157
assert result == mock_model_response
157158

158159
@pytest.mark.asyncio
159-
async def test_async_new_format_keyword_args(self, async_client, mock_model_response):
160+
async def test_async_new_format_keyword_args(self, async_client: AsyncReplicate, mock_model_response: ModelGetResponse) -> None:
160161
"""Test async new format: models.get(model_owner='owner', model_name='name')."""
161162
# Mock the underlying _get method
162163
with patch.object(async_client.models, "_get", return_value=mock_model_response) as mock_get:
@@ -172,7 +173,7 @@ async def test_async_new_format_keyword_args(self, async_client, mock_model_resp
172173
assert result == mock_model_response
173174

174175
@pytest.mark.asyncio
175-
async def test_async_error_mixed_formats(self, async_client):
176+
async def test_async_error_mixed_formats(self, async_client: AsyncReplicate) -> None:
176177
"""Test async error when mixing legacy and new formats."""
177178
with pytest.raises(ValueError) as exc_info:
178179
await async_client.models.get("stability-ai/stable-diffusion", model_owner="other-owner")

0 commit comments

Comments
 (0)