Skip to content

Commit 4b5f45c

Browse files
claude[bot]zeke
andcommitted
feat: improve cog integration with error logging and better organization
- Add debug-level logging for cog.current_scope() failures - Move cog logic to dedicated lib/cog.py module for better organization - Update imports and tests to use new module structure Addresses @dgellow's review suggestions: 1. Added error logging at debug level for better troubleshooting 2. Moved cog-specific code to lib/cog.py to make it obvious it's custom code and reduce risk of git conflicts Co-authored-by: Zeke Sikelianos <[email protected]>
1 parent 353e969 commit 4b5f45c

File tree

3 files changed

+76
-48
lines changed

3 files changed

+76
-48
lines changed

src/replicate/_client.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from replicate.lib._files import FileEncodingStrategy
2525
from replicate.lib._predictions_run import Model, Version, ModelVersionIdentifier
26+
from replicate.lib.cog import get_api_token_from_environment
2627
from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
2728

2829
from . import _exceptions
@@ -77,37 +78,6 @@
7778
]
7879

7980

80-
def _get_api_token_from_environment() -> str | None:
81-
"""Get API token from cog current scope if available, otherwise from environment."""
82-
try:
83-
import cog # type: ignore[import-untyped, import-not-found]
84-
85-
# Get the current scope - this might return None or raise an exception
86-
scope = getattr(cog, "current_scope", lambda: None)()
87-
if scope is None:
88-
return os.environ.get("REPLICATE_API_TOKEN")
89-
90-
# Get the context from the scope
91-
context = getattr(scope, "context", None)
92-
if context is None:
93-
return os.environ.get("REPLICATE_API_TOKEN")
94-
95-
# Get the items method and call it
96-
items_method = getattr(context, "items", None)
97-
if not callable(items_method):
98-
return os.environ.get("REPLICATE_API_TOKEN")
99-
100-
# Iterate through context items looking for the API token
101-
items = cast(Iterator[tuple[Any, Any]], items_method())
102-
for key, value in items:
103-
if str(key).upper() == "REPLICATE_API_TOKEN":
104-
return str(value) if value is not None else value
105-
106-
except Exception: # Catch all exceptions to ensure robust fallback
107-
pass
108-
109-
return os.environ.get("REPLICATE_API_TOKEN")
110-
11181

11282
class Replicate(SyncAPIClient):
11383
# client options
@@ -141,7 +111,7 @@ def __init__(
141111
This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided.
142112
"""
143113
if bearer_token is None:
144-
bearer_token = _get_api_token_from_environment()
114+
bearer_token = get_api_token_from_environment()
145115
if bearer_token is None:
146116
raise ReplicateError(
147117
"The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable"
@@ -452,7 +422,7 @@ def __init__(
452422
This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided.
453423
"""
454424
if bearer_token is None:
455-
bearer_token = _get_api_token_from_environment()
425+
bearer_token = get_api_token_from_environment()
456426
if bearer_token is None:
457427
raise ReplicateError(
458428
"The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable"

src/replicate/lib/cog.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Integration with cog's current_scope functionality."""
2+
3+
import os
4+
import logging
5+
from typing import Iterator, Any, cast
6+
7+
logger = logging.getLogger("replicate")
8+
9+
10+
def get_api_token_from_current_scope() -> str | None:
11+
"""Get API token from cog's current_scope if available, otherwise return None.
12+
13+
This function attempts to retrieve the API token from cog's current_scope context.
14+
It gracefully handles all errors and returns None if cog is not available or if
15+
any part of the retrieval process fails.
16+
17+
Returns:
18+
str | None: The API token from cog's current_scope, or None if not available.
19+
"""
20+
try:
21+
import cog # type: ignore[import-untyped, import-not-found]
22+
23+
# Get the current scope - this might return None or raise an exception
24+
scope = getattr(cog, "current_scope", lambda: None)()
25+
if scope is None:
26+
return None
27+
28+
# Get the context from the scope
29+
context = getattr(scope, "context", None)
30+
if context is None:
31+
return None
32+
33+
# Get the items method and call it
34+
items_method = getattr(context, "items", None)
35+
if not callable(items_method):
36+
return None
37+
38+
# Iterate through context items looking for the API token
39+
items = cast(Iterator[tuple[Any, Any]], items_method())
40+
for key, value in items:
41+
if str(key).upper() == "REPLICATE_API_TOKEN":
42+
return str(value) if value is not None else value
43+
44+
except Exception as e: # Catch all exceptions to ensure robust fallback
45+
logger.debug("Failed to retrieve API token from cog.current_scope(): %s", e)
46+
47+
return None
48+
49+
50+
def get_api_token_from_environment() -> str | None:
51+
"""Get API token from cog current scope if available, otherwise from environment."""
52+
# Try to get token from cog's current_scope first
53+
token = get_api_token_from_current_scope()
54+
if token is not None:
55+
return token
56+
57+
# Fall back to environment variable
58+
return os.environ.get("REPLICATE_API_TOKEN")

tests/test_current_scope.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,20 @@
77
import pytest
88

99
from replicate import Replicate, AsyncReplicate
10-
from replicate._client import _get_api_token_from_environment
10+
from replicate.lib.cog import get_api_token_from_environment
1111
from replicate._exceptions import ReplicateError
1212

1313

1414
class TestGetApiTokenFromEnvironment:
15-
"""Test the _get_api_token_from_environment function."""
15+
"""Test the get_api_token_from_environment function."""
1616

1717
def test_cog_no_current_scope_method_falls_back_to_env(self):
1818
"""Test fallback when cog exists but has no current_scope method."""
1919
mock_cog = mock.MagicMock()
2020
del mock_cog.current_scope # Remove the method
2121
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
2222
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
23-
token = _get_api_token_from_environment()
23+
token = get_api_token_from_environment()
2424
assert token == "env-token"
2525

2626
def test_cog_current_scope_returns_none_falls_back_to_env(self):
@@ -29,7 +29,7 @@ def test_cog_current_scope_returns_none_falls_back_to_env(self):
2929
mock_cog.current_scope.return_value = None
3030
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
3131
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
32-
token = _get_api_token_from_environment()
32+
token = get_api_token_from_environment()
3333
assert token == "env-token"
3434

3535
def test_cog_scope_no_context_attr_falls_back_to_env(self):
@@ -40,7 +40,7 @@ def test_cog_scope_no_context_attr_falls_back_to_env(self):
4040
mock_cog.current_scope.return_value = mock_scope
4141
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
4242
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
43-
token = _get_api_token_from_environment()
43+
token = get_api_token_from_environment()
4444
assert token == "env-token"
4545

4646
def test_cog_scope_context_not_dict_falls_back_to_env(self):
@@ -51,7 +51,7 @@ def test_cog_scope_context_not_dict_falls_back_to_env(self):
5151
mock_cog.current_scope.return_value = mock_scope
5252
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
5353
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
54-
token = _get_api_token_from_environment()
54+
token = get_api_token_from_environment()
5555
assert token == "env-token"
5656

5757
def test_cog_scope_no_replicate_api_token_key_falls_back_to_env(self):
@@ -62,7 +62,7 @@ def test_cog_scope_no_replicate_api_token_key_falls_back_to_env(self):
6262
mock_cog.current_scope.return_value = mock_scope
6363
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
6464
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
65-
token = _get_api_token_from_environment()
65+
token = get_api_token_from_environment()
6666
assert token == "env-token"
6767

6868
def test_cog_scope_replicate_api_token_valid_string(self):
@@ -73,7 +73,7 @@ def test_cog_scope_replicate_api_token_valid_string(self):
7373
mock_cog.current_scope.return_value = mock_scope
7474
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
7575
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
76-
token = _get_api_token_from_environment()
76+
token = get_api_token_from_environment()
7777
assert token == "cog-token"
7878

7979
def test_cog_scope_replicate_api_token_case_insensitive(self):
@@ -84,7 +84,7 @@ def test_cog_scope_replicate_api_token_case_insensitive(self):
8484
mock_cog.current_scope.return_value = mock_scope
8585
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
8686
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
87-
token = _get_api_token_from_environment()
87+
token = get_api_token_from_environment()
8888
assert token == "cog-token"
8989

9090
def test_cog_scope_replicate_api_token_empty_string(self):
@@ -95,7 +95,7 @@ def test_cog_scope_replicate_api_token_empty_string(self):
9595
mock_cog.current_scope.return_value = mock_scope
9696
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
9797
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
98-
token = _get_api_token_from_environment()
98+
token = get_api_token_from_environment()
9999
assert token == "" # Should return empty string, not env token
100100

101101
def test_cog_scope_replicate_api_token_none(self):
@@ -106,7 +106,7 @@ def test_cog_scope_replicate_api_token_none(self):
106106
mock_cog.current_scope.return_value = mock_scope
107107
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
108108
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
109-
token = _get_api_token_from_environment()
109+
token = get_api_token_from_environment()
110110
assert token is None # Should return None, not env token
111111

112112
def test_cog_current_scope_raises_exception_falls_back_to_env(self):
@@ -115,28 +115,28 @@ def test_cog_current_scope_raises_exception_falls_back_to_env(self):
115115
mock_cog.current_scope.side_effect = RuntimeError("Scope error")
116116
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
117117
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
118-
token = _get_api_token_from_environment()
118+
token = get_api_token_from_environment()
119119
assert token == "env-token"
120120

121121
def test_no_env_token_returns_none(self):
122122
"""Test that None is returned when no environment token is set and cog unavailable."""
123123
with mock.patch.dict(os.environ, {}, clear=True): # Clear all env vars
124124
with mock.patch.dict(sys.modules, {"cog": None}):
125-
token = _get_api_token_from_environment()
125+
token = get_api_token_from_environment()
126126
assert token is None
127127

128128
def test_env_token_empty_string(self):
129129
"""Test that empty string from environment is returned."""
130130
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": ""}):
131131
with mock.patch.dict(sys.modules, {"cog": None}):
132-
token = _get_api_token_from_environment()
132+
token = get_api_token_from_environment()
133133
assert token == ""
134134

135135
def test_env_token_valid_string(self):
136136
"""Test that valid token from environment is returned."""
137137
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
138138
with mock.patch.dict(sys.modules, {"cog": None}):
139-
token = _get_api_token_from_environment()
139+
token = get_api_token_from_environment()
140140
assert token == "env-token"
141141

142142

0 commit comments

Comments
 (0)