Skip to content

Commit 4c9e0cb

Browse files
committed
implement models.get("owner/model") support via overloads
1 parent 85bfd92 commit 4c9e0cb

File tree

3 files changed

+244
-168
lines changed

3 files changed

+244
-168
lines changed

src/replicate/resources/models/models.py

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
from typing import overload
56
from typing_extensions import Literal
67

78
import httpx
@@ -297,9 +298,36 @@ def delete(
297298
cast_to=NoneType,
298299
)
299300

301+
@overload
300302
def get(
301303
self,
302-
model_or_owner: str | NotGiven = NOT_GIVEN, # Legacy positional arg
304+
model_owner_and_name: str,
305+
*,
306+
extra_headers: Headers | None = None,
307+
extra_query: Query | None = None,
308+
extra_body: Body | None = None,
309+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
310+
) -> ModelGetResponse:
311+
"""Legacy format: models.get("owner/name")"""
312+
...
313+
314+
@overload
315+
def get(
316+
self,
317+
*,
318+
model_owner: str,
319+
model_name: str,
320+
extra_headers: Headers | None = None,
321+
extra_query: Query | None = None,
322+
extra_body: Body | None = None,
323+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
324+
) -> ModelGetResponse:
325+
"""New format: models.get(model_owner="owner", model_name="name")"""
326+
...
327+
328+
def get(
329+
self,
330+
model_owner_and_name: str | NotGiven = NOT_GIVEN,
303331
*,
304332
model_owner: str | NotGiven = NOT_GIVEN,
305333
model_name: str | NotGiven = NOT_GIVEN,
@@ -390,7 +418,7 @@ def get(
390418
- New: models.get(model_owner="owner", model_name="name")
391419
392420
Args:
393-
model_or_owner: Legacy format string "owner/name" (positional argument)
421+
model_owner_and_name: Legacy format string "owner/name" (positional argument)
394422
model_owner: Model owner (keyword argument)
395423
model_name: Model name (keyword argument)
396424
extra_headers: Send extra headers
@@ -399,25 +427,33 @@ def get(
399427
timeout: Override the client-level default timeout for this request, in seconds
400428
"""
401429
# Handle legacy format: models.get("owner/name")
402-
if model_or_owner is not NOT_GIVEN:
430+
if model_owner_and_name is not NOT_GIVEN:
403431
if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN:
404432
raise ValueError(
405433
"Cannot specify both positional and keyword arguments. "
406434
"Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
407435
)
408436

437+
# Type narrowing - at this point model_owner_and_name must be a string
438+
if not isinstance(model_owner_and_name, str):
439+
raise TypeError("model_owner_and_name must be a string")
440+
409441
# Parse the owner/name format
410-
if "/" not in model_or_owner:
411-
raise ValueError(f"Invalid model reference '{model_or_owner}'. Expected format: 'owner/name'")
442+
if "/" not in model_owner_and_name:
443+
raise ValueError(f"Invalid model reference '{model_owner_and_name}'. Expected format: 'owner/name'")
412444

413-
parts = model_or_owner.split("/", 1)
445+
parts = model_owner_and_name.split("/", 1)
414446
model_owner = parts[0]
415447
model_name = parts[1]
416448

417449
# Validate required parameters
418450
if model_owner is NOT_GIVEN or model_name is NOT_GIVEN:
419451
raise ValueError("model_owner and model_name are required")
420452

453+
# Type narrowing - at this point both must be strings
454+
if not isinstance(model_owner, str) or not isinstance(model_name, str):
455+
raise TypeError("model_owner and model_name must be strings")
456+
421457
if not model_owner:
422458
raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
423459
if not model_name:
@@ -721,9 +757,36 @@ async def delete(
721757
cast_to=NoneType,
722758
)
723759

760+
@overload
724761
async def get(
725762
self,
726-
model_or_owner: str | NotGiven = NOT_GIVEN, # Legacy positional arg
763+
model_owner_and_name: str,
764+
*,
765+
extra_headers: Headers | None = None,
766+
extra_query: Query | None = None,
767+
extra_body: Body | None = None,
768+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
769+
) -> ModelGetResponse:
770+
"""Legacy format: models.get("owner/name")"""
771+
...
772+
773+
@overload
774+
async def get(
775+
self,
776+
*,
777+
model_owner: str,
778+
model_name: str,
779+
extra_headers: Headers | None = None,
780+
extra_query: Query | None = None,
781+
extra_body: Body | None = None,
782+
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
783+
) -> ModelGetResponse:
784+
"""New format: models.get(model_owner="owner", model_name="name")"""
785+
...
786+
787+
async def get(
788+
self,
789+
model_owner_and_name: str | NotGiven = NOT_GIVEN,
727790
*,
728791
model_owner: str | NotGiven = NOT_GIVEN,
729792
model_name: str | NotGiven = NOT_GIVEN,
@@ -814,7 +877,7 @@ async def get(
814877
- New: models.get(model_owner="owner", model_name="name")
815878
816879
Args:
817-
model_or_owner: Legacy format string "owner/name" (positional argument)
880+
model_owner_and_name: Legacy format string "owner/name" (positional argument)
818881
model_owner: Model owner (keyword argument)
819882
model_name: Model name (keyword argument)
820883
extra_headers: Send extra headers
@@ -823,25 +886,33 @@ async def get(
823886
timeout: Override the client-level default timeout for this request, in seconds
824887
"""
825888
# Handle legacy format: models.get("owner/name")
826-
if model_or_owner is not NOT_GIVEN:
889+
if model_owner_and_name is not NOT_GIVEN:
827890
if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN:
828891
raise ValueError(
829892
"Cannot specify both positional and keyword arguments. "
830893
"Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
831894
)
832895

896+
# Type narrowing - at this point model_owner_and_name must be a string
897+
if not isinstance(model_owner_and_name, str):
898+
raise TypeError("model_owner_and_name must be a string")
899+
833900
# Parse the owner/name format
834-
if "/" not in model_or_owner:
835-
raise ValueError(f"Invalid model reference '{model_or_owner}'. Expected format: 'owner/name'")
901+
if "/" not in model_owner_and_name:
902+
raise ValueError(f"Invalid model reference '{model_owner_and_name}'. Expected format: 'owner/name'")
836903

837-
parts = model_or_owner.split("/", 1)
904+
parts = model_owner_and_name.split("/", 1)
838905
model_owner = parts[0]
839906
model_name = parts[1]
840907

841908
# Validate required parameters
842909
if model_owner is NOT_GIVEN or model_name is NOT_GIVEN:
843910
raise ValueError("model_owner and model_name are required")
844911

912+
# Type narrowing - at this point both must be strings
913+
if not isinstance(model_owner, str) or not isinstance(model_name, str):
914+
raise TypeError("model_owner and model_name must be strings")
915+
845916
if not model_owner:
846917
raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
847918
if not model_name:
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
"""Tests for models.get() backward compatibility with legacy owner/name format."""
2+
3+
import os
4+
5+
import httpx
6+
import pytest
7+
from respx import MockRouter
8+
9+
from replicate import Replicate, AsyncReplicate
10+
11+
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
12+
bearer_token = "My Bearer Token"
13+
14+
15+
def mock_model_response():
16+
"""Mock model response data."""
17+
return {
18+
"url": "https://replicate.com/stability-ai/stable-diffusion",
19+
"owner": "stability-ai",
20+
"name": "stable-diffusion",
21+
"description": "A model for generating images from text prompts",
22+
"visibility": "public",
23+
"github_url": None,
24+
"paper_url": None,
25+
"license_url": None,
26+
"run_count": 12345,
27+
"cover_image_url": "https://example.com/cover.jpg",
28+
"default_example": None,
29+
"latest_version": None,
30+
}
31+
32+
33+
class TestModelsGetLegacyFormat:
34+
"""Test legacy format: models.get('owner/name')."""
35+
36+
client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
37+
38+
@pytest.mark.respx(base_url=base_url)
39+
def test_legacy_format_basic(self, respx_mock: MockRouter):
40+
"""Test basic legacy format with owner/name."""
41+
respx_mock.get("/models/stability-ai/stable-diffusion").mock(
42+
return_value=httpx.Response(200, json=mock_model_response())
43+
)
44+
45+
model = self.client.models.get("stability-ai/stable-diffusion")
46+
47+
assert model.owner == "stability-ai"
48+
assert model.name == "stable-diffusion"
49+
50+
@pytest.mark.respx(base_url=base_url)
51+
def test_legacy_format_with_hyphens_and_dots(self, respx_mock: MockRouter):
52+
"""Test legacy format with hyphenated names and dots."""
53+
response_data = {**mock_model_response(), "owner": "black-forest-labs", "name": "flux-1.1-pro"}
54+
respx_mock.get("/models/black-forest-labs/flux-1.1-pro").mock(
55+
return_value=httpx.Response(200, json=response_data)
56+
)
57+
58+
model = self.client.models.get("black-forest-labs/flux-1.1-pro")
59+
60+
assert model.owner == "black-forest-labs"
61+
assert model.name == "flux-1.1-pro"
62+
63+
@pytest.mark.respx(base_url=base_url)
64+
def test_legacy_format_splits_on_first_slash_only(self, respx_mock: MockRouter):
65+
"""Test legacy format splits on first slash only."""
66+
response_data = {**mock_model_response(), "owner": "owner", "name": "name/with/slashes"}
67+
respx_mock.get("/models/owner/name/with/slashes").mock(return_value=httpx.Response(200, json=response_data))
68+
69+
model = self.client.models.get("owner/name/with/slashes")
70+
71+
assert model.owner == "owner"
72+
assert model.name == "name/with/slashes"
73+
74+
def test_legacy_format_error_no_slash(self):
75+
"""Test error when legacy format has no slash."""
76+
with pytest.raises(ValueError, match="Invalid model reference 'invalid-format'.*Expected format: 'owner/name'"):
77+
self.client.models.get("invalid-format")
78+
79+
def test_legacy_format_error_mixed_with_kwargs(self):
80+
"""Test error when mixing positional and keyword arguments."""
81+
with pytest.raises(ValueError, match="Cannot specify both positional and keyword arguments"):
82+
self.client.models.get("owner/name", model_owner="other-owner") # type: ignore[call-overload]
83+
84+
85+
class TestModelsGetNewFormat:
86+
"""Test new format: models.get(model_owner='owner', model_name='name')."""
87+
88+
client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
89+
90+
@pytest.mark.respx(base_url=base_url)
91+
def test_new_format_basic(self, respx_mock: MockRouter):
92+
"""Test basic new format with keyword arguments."""
93+
respx_mock.get("/models/stability-ai/stable-diffusion").mock(
94+
return_value=httpx.Response(200, json=mock_model_response())
95+
)
96+
97+
model = self.client.models.get(model_owner="stability-ai", model_name="stable-diffusion")
98+
99+
assert model.owner == "stability-ai"
100+
assert model.name == "stable-diffusion"
101+
102+
def test_new_format_error_missing_params(self):
103+
"""Test error when required parameters are missing."""
104+
with pytest.raises(ValueError, match="model_owner and model_name are required"):
105+
self.client.models.get() # type: ignore[call-overload]
106+
107+
108+
class TestAsyncModelsGetLegacyFormat:
109+
"""Test async legacy format."""
110+
111+
client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
112+
113+
@pytest.mark.respx(base_url=base_url)
114+
@pytest.mark.asyncio
115+
async def test_async_legacy_format_basic(self, respx_mock: MockRouter):
116+
"""Test async basic legacy format."""
117+
respx_mock.get("/models/stability-ai/stable-diffusion").mock(
118+
return_value=httpx.Response(200, json=mock_model_response())
119+
)
120+
121+
model = await self.client.models.get("stability-ai/stable-diffusion")
122+
123+
assert model.owner == "stability-ai"
124+
assert model.name == "stable-diffusion"
125+
126+
@pytest.mark.asyncio
127+
async def test_async_legacy_format_error_no_slash(self):
128+
"""Test async error when legacy format has no slash."""
129+
with pytest.raises(ValueError, match="Invalid model reference 'invalid-format'.*Expected format: 'owner/name'"):
130+
await self.client.models.get("invalid-format")
131+
132+
@pytest.mark.asyncio
133+
async def test_async_legacy_format_error_mixed(self):
134+
"""Test async error when mixing formats."""
135+
with pytest.raises(ValueError, match="Cannot specify both positional and keyword arguments"):
136+
await self.client.models.get("owner/name", model_owner="other") # type: ignore[call-overload]
137+
138+
139+
class TestAsyncModelsGetNewFormat:
140+
"""Test async new format."""
141+
142+
client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)
143+
144+
@pytest.mark.respx(base_url=base_url)
145+
@pytest.mark.asyncio
146+
async def test_async_new_format_basic(self, respx_mock: MockRouter):
147+
"""Test async new format."""
148+
respx_mock.get("/models/stability-ai/stable-diffusion").mock(
149+
return_value=httpx.Response(200, json=mock_model_response())
150+
)
151+
152+
model = await self.client.models.get(model_owner="stability-ai", model_name="stable-diffusion")
153+
154+
assert model.owner == "stability-ai"
155+
assert model.name == "stable-diffusion"
156+
157+
@pytest.mark.asyncio
158+
async def test_async_new_format_error_missing_params(self):
159+
"""Test async error when required parameters are missing."""
160+
with pytest.raises(ValueError, match="model_owner and model_name are required"):
161+
await self.client.models.get() # type: ignore[call-overload]

0 commit comments

Comments
 (0)