Skip to content

Commit e5e72a5

Browse files
fix: Make auth tokens resolved dynamically per request.
1 parent 4adf1c5 commit e5e72a5

File tree

1 file changed

+69
-11
lines changed

1 file changed

+69
-11
lines changed

singlestoredb/ai/embeddings.py

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from collections.abc import Generator
23
from typing import Any
34
from typing import Callable
45
from typing import Optional
@@ -7,6 +8,7 @@
78
import httpx
89

910
from singlestoredb import manage_workspaces
11+
from singlestoredb.management.inference_api import InferenceAPIInfo
1012

1113
try:
1214
from langchain_openai import OpenAIEmbeddings
@@ -31,19 +33,54 @@
3133

3234
def SingleStoreEmbeddingsFactory(
3335
model_name: str,
34-
api_key: Optional[str] = None,
36+
api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
3537
http_client: Optional[httpx.Client] = None,
36-
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
38+
obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
39+
base_url: Optional[str] = None,
40+
hosting_platform: Optional[str] = None,
3741
**kwargs: Any,
3842
) -> Union[OpenAIEmbeddings, BedrockEmbeddings]:
3943
"""Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings).
4044
"""
41-
inference_api_manager = (
42-
manage_workspaces().organizations.current.inference_apis
43-
)
44-
info = inference_api_manager.get(model_name=model_name)
45-
token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
46-
token = api_key if api_key is not None else token_env
45+
# Handle api_key and obo_token as callable functions
46+
if callable(api_key):
47+
api_key_getter = api_key
48+
else:
49+
def api_key_getter() -> Optional[str]:
50+
if api_key is None:
51+
return os.environ.get('SINGLESTOREDB_USER_TOKEN')
52+
return api_key
53+
54+
if callable(obo_token):
55+
obo_token_getter = obo_token
56+
else:
57+
def obo_token_getter() -> Optional[str]:
58+
return obo_token
59+
60+
# handle model info
61+
if base_url is None:
62+
base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
63+
if hosting_platform is None:
64+
hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM')
65+
66+
if base_url is None or hosting_platform is None:
67+
inference_api_manager = (
68+
manage_workspaces().organizations.current.inference_apis
69+
)
70+
info = inference_api_manager.get(model_name=model_name)
71+
else:
72+
info = InferenceAPIInfo(
73+
service_id='',
74+
model_name=model_name,
75+
name='',
76+
connection_url=base_url,
77+
project_id='',
78+
hosting_platform=hosting_platform,
79+
)
80+
if base_url is not None:
81+
info.connection_url = base_url
82+
if hosting_platform is not None:
83+
info.hosting_platform = hosting_platform
4784

4885
if info.hosting_platform == 'Amazon':
4986
# Instantiate Bedrock client
@@ -85,12 +122,14 @@ def SingleStoreEmbeddingsFactory(
85122

86123
def _inject_headers(request: Any, **_ignored: Any) -> None:
87124
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
125+
if api_key_getter is not None:
126+
token_val = api_key_getter()
127+
if token_val:
128+
request.headers['Authorization'] = f'Bearer {token_val}'
88129
if obo_token_getter is not None:
89130
obo_val = obo_token_getter()
90131
if obo_val:
91132
request.headers['X-S2-OBO'] = obo_val
92-
if token:
93-
request.headers['Authorization'] = f'Bearer {token}'
94133
request.headers.pop('X-Amz-Date', None)
95134
request.headers.pop('X-Amz-Security-Token', None)
96135

@@ -114,10 +153,29 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
114153
**kwargs,
115154
)
116155

156+
class OpenAIAuth(httpx.Auth):
157+
def auth_flow(
158+
self, request: httpx.Request,
159+
) -> Generator[httpx.Request, None, None]:
160+
if api_key_getter is not None:
161+
token_val = api_key_getter()
162+
if token_val:
163+
request.headers['Authorization'] = f'Bearer {token_val}'
164+
if obo_token_getter is not None:
165+
obo_val = obo_token_getter()
166+
if obo_val:
167+
request.headers['X-S2-OBO'] = obo_val
168+
yield request
169+
170+
http_client = httpx.Client(
171+
timeout=30,
172+
auth=OpenAIAuth(),
173+
)
174+
117175
# OpenAI / Azure OpenAI path
118176
openai_kwargs = dict(
119177
base_url=info.connection_url,
120-
api_key=token,
178+
api_key='placeholder',
121179
model=model_name,
122180
)
123181
if http_client is not None:

0 commit comments

Comments
 (0)