Skip to content

Commit 353e969

Browse files
committed
feat(pipelines): get API token from cog's current_scope, if available
1 parent 914202d commit 353e969

File tree

2 files changed

+255
-2
lines changed

2 files changed

+255
-2
lines changed

src/replicate/_client.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Iterator,
1515
Optional,
1616
AsyncIterator,
17+
cast,
1718
overload,
1819
)
1920
from typing_extensions import Self, Unpack, ParamSpec, override
@@ -76,6 +77,38 @@
7677
]
7778

7879

80+
def _get_api_token_from_environment() -> str | None:
81+
"""Get API token from cog current scope if available, otherwise from environment."""
82+
try:
83+
import cog # type: ignore[import-untyped, import-not-found]
84+
85+
# Get the current scope - this might return None or raise an exception
86+
scope = getattr(cog, "current_scope", lambda: None)()
87+
if scope is None:
88+
return os.environ.get("REPLICATE_API_TOKEN")
89+
90+
# Get the context from the scope
91+
context = getattr(scope, "context", None)
92+
if context is None:
93+
return os.environ.get("REPLICATE_API_TOKEN")
94+
95+
# Get the items method and call it
96+
items_method = getattr(context, "items", None)
97+
if not callable(items_method):
98+
return os.environ.get("REPLICATE_API_TOKEN")
99+
100+
# Iterate through context items looking for the API token
101+
items = cast(Iterator[tuple[Any, Any]], items_method())
102+
for key, value in items:
103+
if str(key).upper() == "REPLICATE_API_TOKEN":
104+
return str(value) if value is not None else value
105+
106+
except Exception: # Catch all exceptions to ensure robust fallback
107+
pass
108+
109+
return os.environ.get("REPLICATE_API_TOKEN")
110+
111+
79112
class Replicate(SyncAPIClient):
80113
# client options
81114
bearer_token: str
@@ -108,7 +141,7 @@ def __init__(
108141
This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided.
109142
"""
110143
if bearer_token is None:
111-
bearer_token = os.environ.get("REPLICATE_API_TOKEN")
144+
bearer_token = _get_api_token_from_environment()
112145
if bearer_token is None:
113146
raise ReplicateError(
114147
"The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable"
@@ -419,7 +452,7 @@ def __init__(
419452
This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided.
420453
"""
421454
if bearer_token is None:
422-
bearer_token = os.environ.get("REPLICATE_API_TOKEN")
455+
bearer_token = _get_api_token_from_environment()
423456
if bearer_token is None:
424457
raise ReplicateError(
425458
"The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable"

tests/test_current_scope.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
"""Tests for current_scope token functionality."""
2+
3+
import os
4+
import sys
5+
from unittest import mock
6+
7+
import pytest
8+
9+
from replicate import Replicate, AsyncReplicate
10+
from replicate._client import _get_api_token_from_environment
11+
from replicate._exceptions import ReplicateError
12+
13+
14+
class TestGetApiTokenFromEnvironment:
15+
"""Test the _get_api_token_from_environment function."""
16+
17+
def test_cog_no_current_scope_method_falls_back_to_env(self):
18+
"""Test fallback when cog exists but has no current_scope method."""
19+
mock_cog = mock.MagicMock()
20+
del mock_cog.current_scope # Remove the method
21+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
22+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
23+
token = _get_api_token_from_environment()
24+
assert token == "env-token"
25+
26+
def test_cog_current_scope_returns_none_falls_back_to_env(self):
27+
"""Test fallback when current_scope() returns None."""
28+
mock_cog = mock.MagicMock()
29+
mock_cog.current_scope.return_value = None
30+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
31+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
32+
token = _get_api_token_from_environment()
33+
assert token == "env-token"
34+
35+
def test_cog_scope_no_context_attr_falls_back_to_env(self):
36+
"""Test fallback when scope has no context attribute."""
37+
mock_scope = mock.MagicMock()
38+
del mock_scope.context # Remove the context attribute
39+
mock_cog = mock.MagicMock()
40+
mock_cog.current_scope.return_value = mock_scope
41+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
42+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
43+
token = _get_api_token_from_environment()
44+
assert token == "env-token"
45+
46+
def test_cog_scope_context_not_dict_falls_back_to_env(self):
47+
"""Test fallback when scope.context is not a dictionary."""
48+
mock_scope = mock.MagicMock()
49+
mock_scope.context = "not a dict"
50+
mock_cog = mock.MagicMock()
51+
mock_cog.current_scope.return_value = mock_scope
52+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
53+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
54+
token = _get_api_token_from_environment()
55+
assert token == "env-token"
56+
57+
def test_cog_scope_no_replicate_api_token_key_falls_back_to_env(self):
58+
"""Test fallback when replicate_api_token key is missing from context."""
59+
mock_scope = mock.MagicMock()
60+
mock_scope.context = {"other_key": "other_value"} # Missing replicate_api_token
61+
mock_cog = mock.MagicMock()
62+
mock_cog.current_scope.return_value = mock_scope
63+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
64+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
65+
token = _get_api_token_from_environment()
66+
assert token == "env-token"
67+
68+
def test_cog_scope_replicate_api_token_valid_string(self):
69+
"""Test successful retrieval of non-empty token from cog."""
70+
mock_scope = mock.MagicMock()
71+
mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"}
72+
mock_cog = mock.MagicMock()
73+
mock_cog.current_scope.return_value = mock_scope
74+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
75+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
76+
token = _get_api_token_from_environment()
77+
assert token == "cog-token"
78+
79+
def test_cog_scope_replicate_api_token_case_insensitive(self):
80+
"""Test successful retrieval of non-empty token from cog ignoring case."""
81+
mock_scope = mock.MagicMock()
82+
mock_scope.context = {"replicate_api_token": "cog-token"}
83+
mock_cog = mock.MagicMock()
84+
mock_cog.current_scope.return_value = mock_scope
85+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
86+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
87+
token = _get_api_token_from_environment()
88+
assert token == "cog-token"
89+
90+
def test_cog_scope_replicate_api_token_empty_string(self):
91+
"""Test that empty string from cog is returned (not falling back to env)."""
92+
mock_scope = mock.MagicMock()
93+
mock_scope.context = {"replicate_api_token": ""} # Empty string
94+
mock_cog = mock.MagicMock()
95+
mock_cog.current_scope.return_value = mock_scope
96+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
97+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
98+
token = _get_api_token_from_environment()
99+
assert token == "" # Should return empty string, not env token
100+
101+
def test_cog_scope_replicate_api_token_none(self):
102+
"""Test that None from cog is returned (not falling back to env)."""
103+
mock_scope = mock.MagicMock()
104+
mock_scope.context = {"replicate_api_token": None}
105+
mock_cog = mock.MagicMock()
106+
mock_cog.current_scope.return_value = mock_scope
107+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
108+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
109+
token = _get_api_token_from_environment()
110+
assert token is None # Should return None, not env token
111+
112+
def test_cog_current_scope_raises_exception_falls_back_to_env(self):
113+
"""Test fallback when current_scope() raises an exception."""
114+
mock_cog = mock.MagicMock()
115+
mock_cog.current_scope.side_effect = RuntimeError("Scope error")
116+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
117+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
118+
token = _get_api_token_from_environment()
119+
assert token == "env-token"
120+
121+
def test_no_env_token_returns_none(self):
122+
"""Test that None is returned when no environment token is set and cog unavailable."""
123+
with mock.patch.dict(os.environ, {}, clear=True): # Clear all env vars
124+
with mock.patch.dict(sys.modules, {"cog": None}):
125+
token = _get_api_token_from_environment()
126+
assert token is None
127+
128+
def test_env_token_empty_string(self):
129+
"""Test that empty string from environment is returned."""
130+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": ""}):
131+
with mock.patch.dict(sys.modules, {"cog": None}):
132+
token = _get_api_token_from_environment()
133+
assert token == ""
134+
135+
def test_env_token_valid_string(self):
136+
"""Test that valid token from environment is returned."""
137+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
138+
with mock.patch.dict(sys.modules, {"cog": None}):
139+
token = _get_api_token_from_environment()
140+
assert token == "env-token"
141+
142+
143+
class TestClientCurrentScopeIntegration:
144+
"""Test that the client uses current_scope functionality."""
145+
146+
def test_sync_client_uses_current_scope_token(self):
147+
"""Test that sync client retrieves token from current_scope."""
148+
mock_scope = mock.MagicMock()
149+
mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"}
150+
mock_cog = mock.MagicMock()
151+
mock_cog.current_scope.return_value = mock_scope
152+
153+
# Clear environment variable to ensure we're using cog
154+
with mock.patch.dict(os.environ, {}, clear=True):
155+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
156+
client = Replicate(base_url="http://test.example.com")
157+
assert client.bearer_token == "cog-token"
158+
159+
def test_async_client_uses_current_scope_token(self):
160+
"""Test that async client retrieves token from current_scope."""
161+
mock_scope = mock.MagicMock()
162+
mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"}
163+
mock_cog = mock.MagicMock()
164+
mock_cog.current_scope.return_value = mock_scope
165+
166+
# Clear environment variable to ensure we're using cog
167+
with mock.patch.dict(os.environ, {}, clear=True):
168+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
169+
client = AsyncReplicate(base_url="http://test.example.com")
170+
assert client.bearer_token == "cog-token"
171+
172+
def test_sync_client_falls_back_to_env_when_cog_unavailable(self):
173+
"""Test that sync client falls back to env when cog is unavailable."""
174+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
175+
with mock.patch.dict(sys.modules, {"cog": None}):
176+
client = Replicate(base_url="http://test.example.com")
177+
assert client.bearer_token == "env-token"
178+
179+
def test_async_client_falls_back_to_env_when_cog_unavailable(self):
180+
"""Test that async client falls back to env when cog is unavailable."""
181+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
182+
with mock.patch.dict(sys.modules, {"cog": None}):
183+
client = AsyncReplicate(base_url="http://test.example.com")
184+
assert client.bearer_token == "env-token"
185+
186+
def test_sync_client_raises_error_when_no_token_available(self):
187+
"""Test that sync client raises error when no token is available."""
188+
with mock.patch.dict(os.environ, {}, clear=True):
189+
with mock.patch.dict(sys.modules, {"cog": None}):
190+
with pytest.raises(ReplicateError, match="bearer_token client option must be set"):
191+
Replicate(base_url="http://test.example.com")
192+
193+
def test_async_client_raises_error_when_no_token_available(self):
194+
"""Test that async client raises error when no token is available."""
195+
with mock.patch.dict(os.environ, {}, clear=True):
196+
with mock.patch.dict(sys.modules, {"cog": None}):
197+
with pytest.raises(ReplicateError, match="bearer_token client option must be set"):
198+
AsyncReplicate(base_url="http://test.example.com")
199+
200+
def test_explicit_token_overrides_current_scope(self):
201+
"""Test that explicitly provided token overrides current_scope."""
202+
mock_scope = mock.MagicMock()
203+
mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"}
204+
mock_cog = mock.MagicMock()
205+
mock_cog.current_scope.return_value = mock_scope
206+
207+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
208+
client = Replicate(bearer_token="explicit-token", base_url="http://test.example.com")
209+
assert client.bearer_token == "explicit-token"
210+
211+
def test_explicit_async_token_overrides_current_scope(self):
212+
"""Test that explicitly provided token overrides current_scope for async client."""
213+
mock_scope = mock.MagicMock()
214+
mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"}
215+
mock_cog = mock.MagicMock()
216+
mock_cog.current_scope.return_value = mock_scope
217+
218+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
219+
client = AsyncReplicate(bearer_token="explicit-token", base_url="http://test.example.com")
220+
assert client.bearer_token == "explicit-token"

0 commit comments

Comments
 (0)