diff --git a/requirements-dev.lock b/requirements-dev.lock index 839ba5d..bebf15e 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -78,9 +78,10 @@ mypy-extensions==1.0.0 nodeenv==1.8.0 # via pyright nox==2023.4.22 -packaging==23.2 +packaging==25.0 # via nox # via pytest + # via replicate platformdirs==3.11.0 # via virtualenv pluggy==1.5.0 diff --git a/requirements.lock b/requirements.lock index 9c126e0..9a52dad 100644 --- a/requirements.lock +++ b/requirements.lock @@ -52,6 +52,8 @@ idna==3.4 multidict==6.4.4 # via aiohttp # via yarl +packaging==25.0 + # via replicate propcache==0.3.1 # via aiohttp # via yarl diff --git a/src/replicate/_client.py b/src/replicate/_client.py index 390a552..296a138 100644 --- a/src/replicate/_client.py +++ b/src/replicate/_client.py @@ -102,6 +102,7 @@ def __init__( self, *, bearer_token: str | None = None, + api_token: str | None = None, # Legacy compatibility parameter base_url: str | httpx.URL | None = None, timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, @@ -124,7 +125,17 @@ def __init__( """Construct a new synchronous Replicate client instance. This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided. + + For legacy compatibility, you can also pass `api_token` instead of `bearer_token`. """ + # Handle legacy api_token parameter + if api_token is not None and bearer_token is not None: + raise ReplicateError( + "Cannot specify both 'bearer_token' and 'api_token'. Please use 'bearer_token' (recommended) or 'api_token' for legacy compatibility." + ) + if api_token is not None: + bearer_token = api_token + if bearer_token is None: bearer_token = _get_api_token_from_environment() if bearer_token is None: @@ -324,6 +335,7 @@ def copy( self, *, bearer_token: str | None = None, + api_token: str | None = None, # Legacy compatibility parameter base_url: str | httpx.URL | None = None, timeout: float | Timeout | None | NotGiven = NOT_GIVEN, http_client: httpx.Client | None = None, @@ -336,7 +348,17 @@ def copy( ) -> Self: """ Create a new client instance re-using the same options given to the current client with optional overriding. + + For legacy compatibility, you can also pass `api_token` instead of `bearer_token`. """ + # Handle legacy api_token parameter + if api_token is not None and bearer_token is not None: + raise ValueError( + "Cannot specify both 'bearer_token' and 'api_token'. Please use 'bearer_token' (recommended) or 'api_token' for legacy compatibility." + ) + if api_token is not None: + bearer_token = api_token + if default_headers is not None and set_default_headers is not None: raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") @@ -477,6 +499,7 @@ def __init__( self, *, bearer_token: str | None = None, + api_token: str | None = None, # Legacy compatibility parameter base_url: str | httpx.URL | None = None, timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, @@ -499,7 +522,17 @@ def __init__( """Construct a new async AsyncReplicate client instance. This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided. + + For legacy compatibility, you can also pass `api_token` instead of `bearer_token`. """ + # Handle legacy api_token parameter + if api_token is not None and bearer_token is not None: + raise ReplicateError( + "Cannot specify both 'bearer_token' and 'api_token'. Please use 'bearer_token' (recommended) or 'api_token' for legacy compatibility." + ) + if api_token is not None: + bearer_token = api_token + if bearer_token is None: bearer_token = _get_api_token_from_environment() if bearer_token is None: @@ -699,6 +732,7 @@ def copy( self, *, bearer_token: str | None = None, + api_token: str | None = None, # Legacy compatibility parameter base_url: str | httpx.URL | None = None, timeout: float | Timeout | None | NotGiven = NOT_GIVEN, http_client: httpx.AsyncClient | None = None, @@ -711,7 +745,17 @@ def copy( ) -> Self: """ Create a new client instance re-using the same options given to the current client with optional overriding. + + For legacy compatibility, you can also pass `api_token` instead of `bearer_token`. """ + # Handle legacy api_token parameter + if api_token is not None and bearer_token is not None: + raise ValueError( + "Cannot specify both 'bearer_token' and 'api_token'. Please use 'bearer_token' (recommended) or 'api_token' for legacy compatibility." + ) + if api_token is not None: + bearer_token = api_token + if default_headers is not None and set_default_headers is not None: raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive") diff --git a/src/replicate/model.py b/src/replicate/model.py new file mode 100644 index 0000000..fd4823b --- /dev/null +++ b/src/replicate/model.py @@ -0,0 +1,35 @@ +""" +Legacy compatibility module for replicate-python v1.x type names. + +This module provides backward compatibility for code that imports types +using the old v1.x import paths like: + from replicate.model import Model + from replicate.model import Version +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .types import Prediction as Prediction + +# Import the actual types from their new locations +from .lib._models import Model as Model, Version as Version + +# Also provide aliases for the response types for type checking +if TYPE_CHECKING: + from .types import ModelGetResponse as ModelResponse + from .types.models.version_get_response import VersionGetResponse as VersionResponse +else: + # At runtime, make the response types available under legacy names + from .types import ModelGetResponse as ModelResponse + from .types.models.version_get_response import VersionGetResponse as VersionResponse + +__all__ = [ + "Model", + "Version", + "Prediction", + "ModelResponse", + "VersionResponse", +] + diff --git a/tests/test_api_token_compatibility.py b/tests/test_api_token_compatibility.py new file mode 100644 index 0000000..f45a541 --- /dev/null +++ b/tests/test_api_token_compatibility.py @@ -0,0 +1,88 @@ +"""Tests for api_token legacy compatibility during client instantiation.""" + +from __future__ import annotations + +import pytest + +from replicate import Replicate, AsyncReplicate, ReplicateError +from replicate._client import Client + + +class TestApiTokenCompatibility: + """Test that api_token parameter works as a legacy compatibility option.""" + + def test_sync_client_with_api_token(self) -> None: + """Test that Replicate accepts api_token parameter.""" + client = Replicate(api_token="test_token_123") + assert client.bearer_token == "test_token_123" + + def test_async_client_with_api_token(self) -> None: + """Test that AsyncReplicate accepts api_token parameter.""" + client = AsyncReplicate(api_token="test_token_123") + assert client.bearer_token == "test_token_123" + + def test_sync_client_with_bearer_token(self) -> None: + """Test that Replicate still accepts bearer_token parameter.""" + client = Replicate(bearer_token="test_token_123") + assert client.bearer_token == "test_token_123" + + def test_async_client_with_bearer_token(self) -> None: + """Test that AsyncReplicate still accepts bearer_token parameter.""" + client = AsyncReplicate(bearer_token="test_token_123") + assert client.bearer_token == "test_token_123" + + def test_sync_client_both_tokens_error(self) -> None: + """Test that providing both api_token and bearer_token raises an error.""" + with pytest.raises(ReplicateError, match="Cannot specify both 'bearer_token' and 'api_token'"): + Replicate(api_token="test_api", bearer_token="test_bearer") + + def test_async_client_both_tokens_error(self) -> None: + """Test that providing both api_token and bearer_token raises an error.""" + with pytest.raises(ReplicateError, match="Cannot specify both 'bearer_token' and 'api_token'"): + AsyncReplicate(api_token="test_api", bearer_token="test_bearer") + + def test_sync_client_no_token_with_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that client reads from environment when no token is provided.""" + monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token_123") + client = Replicate() + assert client.bearer_token == "env_token_123" + + def test_async_client_no_token_with_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that async client reads from environment when no token is provided.""" + monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token_123") + client = AsyncReplicate() + assert client.bearer_token == "env_token_123" + + def test_sync_client_no_token_no_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that client raises error when no token is provided and env is not set.""" + monkeypatch.delenv("REPLICATE_API_TOKEN", raising=False) + with pytest.raises(ReplicateError, match="The bearer_token client option must be set"): + Replicate() + + def test_async_client_no_token_no_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that async client raises error when no token is provided and env is not set.""" + monkeypatch.delenv("REPLICATE_API_TOKEN", raising=False) + with pytest.raises(ReplicateError, match="The bearer_token client option must be set"): + AsyncReplicate() + + def test_legacy_client_alias(self) -> None: + """Test that legacy Client import still works as an alias.""" + assert Client is Replicate + + def test_legacy_client_with_api_token(self) -> None: + """Test that legacy Client alias works with api_token parameter.""" + client = Client(api_token="test_token_123") + assert client.bearer_token == "test_token_123" + assert isinstance(client, Replicate) + + def test_api_token_overrides_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that explicit api_token overrides environment variable.""" + monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token") + client = Replicate(api_token="explicit_token") + assert client.bearer_token == "explicit_token" + + def test_bearer_token_overrides_env(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that explicit bearer_token overrides environment variable.""" + monkeypatch.setenv("REPLICATE_API_TOKEN", "env_token") + client = Replicate(bearer_token="explicit_token") + assert client.bearer_token == "explicit_token" diff --git a/tests/test_legacy_compatibility.py b/tests/test_legacy_compatibility.py new file mode 100644 index 0000000..3a1b4e6 --- /dev/null +++ b/tests/test_legacy_compatibility.py @@ -0,0 +1,159 @@ +"""Tests for legacy v1.x type compatibility.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from replicate.types import Prediction, ModelGetResponse +from replicate.types.models.version_get_response import VersionGetResponse + + +def test_legacy_model_imports(): + """Test that legacy import paths work.""" + # Test importing Model from legacy path + from replicate.model import Model + from replicate.lib._models import Model as LibModel + + # Verify they are the same class + assert Model is LibModel + + +def test_legacy_version_imports(): + """Test that legacy Version import paths work.""" + # Test importing Version from legacy path + from replicate.model import Version + from replicate.lib._models import Version as LibVersion + + # Verify they are the same class + assert Version is LibVersion + + +def test_legacy_prediction_imports(): + """Test that legacy Prediction import paths work.""" + # Test importing Prediction from legacy path + from replicate.model import Prediction as LegacyPrediction + from replicate.types import Prediction as TypesPrediction + + # Verify they are the same class + assert LegacyPrediction is TypesPrediction + + +def test_legacy_response_type_imports(): + """Test that legacy response type aliases work.""" + # Test importing response types from legacy path + from replicate.model import ModelResponse, VersionResponse + + # Verify they are the correct types + assert ModelResponse is ModelGetResponse + assert VersionResponse is VersionGetResponse + + +def test_legacy_isinstance_checks_with_model(): + """Test that isinstance checks work with legacy Model type.""" + from replicate.model import Model + + # Create an instance + model = Model(owner="test", name="model") + + # Test isinstance check + assert isinstance(model, Model) + + +def test_legacy_isinstance_checks_with_version(): + """Test that isinstance checks work with legacy Version type.""" + import datetime + + from replicate.model import Version + + # Create a Version instance + version = Version(id="test-version-id", created_at=datetime.datetime.now(), cog_version="0.8.0", openapi_schema={}) + + # Test isinstance check + assert isinstance(version, Version) + + +def test_legacy_isinstance_checks_with_prediction(): + """Test that isinstance checks work with legacy Prediction type.""" + import datetime + + from replicate.model import Prediction + + # Create a Prediction instance using construct to bypass validation + prediction = Prediction.construct( + id="test-prediction-id", + created_at=datetime.datetime.now(), + data_removed=False, + input={}, + model="test/model", + output=None, + status="succeeded", + urls={ + "cancel": "https://example.com/cancel", + "get": "https://example.com/get", + "web": "https://example.com/web", + }, + version="test-version", + ) + + # Test isinstance check + assert isinstance(prediction, Prediction) + + +def test_legacy_isinstance_checks_with_model_response(): + """Test that isinstance checks work with ModelResponse alias.""" + from replicate.model import ModelResponse + + # Create a ModelGetResponse instance + model = ModelGetResponse.construct(name="test-model", owner="test-owner") + + # Test isinstance check with both the alias and the actual type + assert isinstance(model, ModelResponse) + assert isinstance(model, ModelGetResponse) + + +def test_legacy_isinstance_checks_with_version_response(): + """Test that isinstance checks work with VersionResponse alias.""" + import datetime + + from replicate.model import VersionResponse + + # Create a VersionGetResponse instance + version = VersionGetResponse.construct( + id="test-version-id", created_at=datetime.datetime.now(), cog_version="0.8.0", openapi_schema={} + ) + + # Test isinstance check with both the alias and the actual type + assert isinstance(version, VersionResponse) + assert isinstance(version, VersionGetResponse) + + +def test_all_exports(): + """Test that __all__ exports the expected items.""" + from replicate import model + + expected_exports = { + "Model", + "Version", + "Prediction", + "ModelResponse", + "VersionResponse", + } + + assert set(model.__all__) == expected_exports + + # Verify all exported items are importable + for name in model.__all__: + assert hasattr(model, name) + + +if TYPE_CHECKING: + # Type checking test - ensure type annotations work correctly + def type_annotation_test(): + from replicate.model import Model, Version, ModelResponse, VersionResponse + + model: Model = Model("owner", "name") # pyright: ignore[reportUnusedVariable] + version: Version # pyright: ignore[reportUnusedVariable] + prediction: Prediction # pyright: ignore[reportUnusedVariable] + model_response: ModelResponse # pyright: ignore[reportUnusedVariable] + version_response: VersionResponse # pyright: ignore[reportUnusedVariable] +