Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 127 additions & 6 deletions src/replicate/resources/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from typing import overload
from typing_extensions import Literal

import httpx
Expand Down Expand Up @@ -297,11 +298,39 @@ def delete(
cast_to=NoneType,
)

@overload
def get(
self,
model_owner_and_name: str,
*,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ModelGetResponse:
"""Legacy format: models.get("owner/name")"""
...

@overload
def get(
self,
*,
model_owner: str,
model_name: str,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ModelGetResponse:
"""New format: models.get(model_owner="owner", model_name="name")"""
...

def get(
self,
model_owner_and_name: str | NotGiven = NOT_GIVEN,
*,
model_owner: str | NotGiven = NOT_GIVEN,
model_name: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
Expand Down Expand Up @@ -384,15 +413,47 @@ def get(
The `latest_version` object is the model's most recently pushed
[version](#models.versions.get).

Supports both legacy and new formats:
- Legacy: models.get("owner/name")
- New: models.get(model_owner="owner", model_name="name")

Args:
model_owner_and_name: Legacy format string "owner/name" (positional argument)
model_owner: Model owner (keyword argument)
model_name: Model name (keyword argument)
extra_headers: Send extra headers

extra_query: Add additional query parameters to the request

extra_body: Add additional JSON properties to the request

timeout: Override the client-level default timeout for this request, in seconds
"""
# Handle legacy format: models.get("owner/name")
if model_owner_and_name is not NOT_GIVEN:
if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN:
raise ValueError(
"Cannot specify both positional and keyword arguments. "
"Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
)

# Type narrowing - at this point model_owner_and_name must be a string
if not isinstance(model_owner_and_name, str):
raise TypeError("model_owner_and_name must be a string")

# Parse the owner/name format
if "/" not in model_owner_and_name:
raise ValueError(f"Invalid model reference '{model_owner_and_name}'. Expected format: 'owner/name'")

parts = model_owner_and_name.split("/", 1)
model_owner = parts[0]
model_name = parts[1]

# Validate required parameters
if model_owner is NOT_GIVEN or model_name is NOT_GIVEN:
raise ValueError("model_owner and model_name are required")

# Type narrowing - at this point both must be strings
if not isinstance(model_owner, str) or not isinstance(model_name, str):
raise TypeError("model_owner and model_name must be strings")

if not model_owner:
raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
if not model_name:
Expand Down Expand Up @@ -696,11 +757,39 @@ async def delete(
cast_to=NoneType,
)

@overload
async def get(
self,
model_owner_and_name: str,
*,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ModelGetResponse:
"""Legacy format: models.get("owner/name")"""
...

@overload
async def get(
self,
*,
model_owner: str,
model_name: str,
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ModelGetResponse:
"""New format: models.get(model_owner="owner", model_name="name")"""
...

async def get(
self,
model_owner_and_name: str | NotGiven = NOT_GIVEN,
*,
model_owner: str | NotGiven = NOT_GIVEN,
model_name: str | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
Expand Down Expand Up @@ -783,15 +872,47 @@ async def get(
The `latest_version` object is the model's most recently pushed
[version](#models.versions.get).

Supports both legacy and new formats:
- Legacy: models.get("owner/name")
- New: models.get(model_owner="owner", model_name="name")

Args:
model_owner_and_name: Legacy format string "owner/name" (positional argument)
model_owner: Model owner (keyword argument)
model_name: Model name (keyword argument)
extra_headers: Send extra headers

extra_query: Add additional query parameters to the request

extra_body: Add additional JSON properties to the request

timeout: Override the client-level default timeout for this request, in seconds
"""
# Handle legacy format: models.get("owner/name")
if model_owner_and_name is not NOT_GIVEN:
if model_owner is not NOT_GIVEN or model_name is not NOT_GIVEN:
raise ValueError(
"Cannot specify both positional and keyword arguments. "
"Use either models.get('owner/name') or models.get(model_owner='owner', model_name='name')"
)

# Type narrowing - at this point model_owner_and_name must be a string
if not isinstance(model_owner_and_name, str):
raise TypeError("model_owner_and_name must be a string")

# Parse the owner/name format
if "/" not in model_owner_and_name:
raise ValueError(f"Invalid model reference '{model_owner_and_name}'. Expected format: 'owner/name'")

parts = model_owner_and_name.split("/", 1)
model_owner = parts[0]
model_name = parts[1]

# Validate required parameters
if model_owner is NOT_GIVEN or model_name is NOT_GIVEN:
raise ValueError("model_owner and model_name are required")

# Type narrowing - at this point both must be strings
if not isinstance(model_owner, str) or not isinstance(model_name, str):
raise TypeError("model_owner and model_name must be strings")

if not model_owner:
raise ValueError(f"Expected a non-empty value for `model_owner` but received {model_owner!r}")
if not model_name:
Expand Down
161 changes: 161 additions & 0 deletions tests/lib/test_models_get_backward_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""Tests for models.get() backward compatibility with legacy owner/name format."""

import os

import httpx
import pytest
from respx import MockRouter

from replicate import Replicate, AsyncReplicate

base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
bearer_token = "My Bearer Token"


def mock_model_response():
"""Mock model response data."""
return {
"url": "https://replicate.com/stability-ai/stable-diffusion",
"owner": "stability-ai",
"name": "stable-diffusion",
"description": "A model for generating images from text prompts",
"visibility": "public",
"github_url": None,
"paper_url": None,
"license_url": None,
"run_count": 12345,
"cover_image_url": "https://example.com/cover.jpg",
"default_example": None,
"latest_version": None,
}


class TestModelsGetLegacyFormat:
"""Test legacy format: models.get('owner/name')."""

client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)

@pytest.mark.respx(base_url=base_url)
def test_legacy_format_basic(self, respx_mock: MockRouter):
"""Test basic legacy format with owner/name."""
respx_mock.get("/models/stability-ai/stable-diffusion").mock(
return_value=httpx.Response(200, json=mock_model_response())
)

model = self.client.models.get("stability-ai/stable-diffusion")

assert model.owner == "stability-ai"
assert model.name == "stable-diffusion"

@pytest.mark.respx(base_url=base_url)
def test_legacy_format_with_hyphens_and_dots(self, respx_mock: MockRouter):
"""Test legacy format with hyphenated names and dots."""
response_data = {**mock_model_response(), "owner": "black-forest-labs", "name": "flux-1.1-pro"}
respx_mock.get("/models/black-forest-labs/flux-1.1-pro").mock(
return_value=httpx.Response(200, json=response_data)
)

model = self.client.models.get("black-forest-labs/flux-1.1-pro")

assert model.owner == "black-forest-labs"
assert model.name == "flux-1.1-pro"

@pytest.mark.respx(base_url=base_url)
def test_legacy_format_splits_on_first_slash_only(self, respx_mock: MockRouter):
"""Test legacy format splits on first slash only."""
response_data = {**mock_model_response(), "owner": "owner", "name": "name/with/slashes"}
respx_mock.get("/models/owner/name/with/slashes").mock(return_value=httpx.Response(200, json=response_data))

model = self.client.models.get("owner/name/with/slashes")

assert model.owner == "owner"
assert model.name == "name/with/slashes"

def test_legacy_format_error_no_slash(self):
"""Test error when legacy format has no slash."""
with pytest.raises(ValueError, match="Invalid model reference 'invalid-format'.*Expected format: 'owner/name'"):
self.client.models.get("invalid-format")

def test_legacy_format_error_mixed_with_kwargs(self):
"""Test error when mixing positional and keyword arguments."""
with pytest.raises(ValueError, match="Cannot specify both positional and keyword arguments"):
self.client.models.get("owner/name", model_owner="other-owner") # type: ignore[call-overload]


class TestModelsGetNewFormat:
"""Test new format: models.get(model_owner='owner', model_name='name')."""

client = Replicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)

@pytest.mark.respx(base_url=base_url)
def test_new_format_basic(self, respx_mock: MockRouter):
"""Test basic new format with keyword arguments."""
respx_mock.get("/models/stability-ai/stable-diffusion").mock(
return_value=httpx.Response(200, json=mock_model_response())
)

model = self.client.models.get(model_owner="stability-ai", model_name="stable-diffusion")

assert model.owner == "stability-ai"
assert model.name == "stable-diffusion"

def test_new_format_error_missing_params(self):
"""Test error when required parameters are missing."""
with pytest.raises(ValueError, match="model_owner and model_name are required"):
self.client.models.get() # type: ignore[call-overload]


class TestAsyncModelsGetLegacyFormat:
"""Test async legacy format."""

client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)

@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_async_legacy_format_basic(self, respx_mock: MockRouter):
"""Test async basic legacy format."""
respx_mock.get("/models/stability-ai/stable-diffusion").mock(
return_value=httpx.Response(200, json=mock_model_response())
)

model = await self.client.models.get("stability-ai/stable-diffusion")

assert model.owner == "stability-ai"
assert model.name == "stable-diffusion"

@pytest.mark.asyncio
async def test_async_legacy_format_error_no_slash(self):
"""Test async error when legacy format has no slash."""
with pytest.raises(ValueError, match="Invalid model reference 'invalid-format'.*Expected format: 'owner/name'"):
await self.client.models.get("invalid-format")

@pytest.mark.asyncio
async def test_async_legacy_format_error_mixed(self):
"""Test async error when mixing formats."""
with pytest.raises(ValueError, match="Cannot specify both positional and keyword arguments"):
await self.client.models.get("owner/name", model_owner="other") # type: ignore[call-overload]


class TestAsyncModelsGetNewFormat:
"""Test async new format."""

client = AsyncReplicate(base_url=base_url, bearer_token=bearer_token, _strict_response_validation=True)

@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_async_new_format_basic(self, respx_mock: MockRouter):
"""Test async new format."""
respx_mock.get("/models/stability-ai/stable-diffusion").mock(
return_value=httpx.Response(200, json=mock_model_response())
)

model = await self.client.models.get(model_owner="stability-ai", model_name="stable-diffusion")

assert model.owner == "stability-ai"
assert model.name == "stable-diffusion"

@pytest.mark.asyncio
async def test_async_new_format_error_missing_params(self):
"""Test async error when required parameters are missing."""
with pytest.raises(ValueError, match="model_owner and model_name are required"):
await self.client.models.get() # type: ignore[call-overload]