Skip to content

Commit 0f83705

Browse files
Apply same changes to ChatFactory as well.
1 parent e5e72a5 commit 0f83705

File tree

2 files changed

+68
-12
lines changed

2 files changed

+68
-12
lines changed

singlestoredb/ai/chat.py

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from collections.abc import Generator
23
from typing import Any
34
from typing import Callable
45
from typing import Optional
@@ -7,6 +8,7 @@
78
import httpx
89

910
from singlestoredb import manage_workspaces
11+
from singlestoredb.management.inference_api import InferenceAPIInfo
1012

1113
try:
1214
from langchain_openai import ChatOpenAI
@@ -31,20 +33,54 @@
3133

3234
def SingleStoreChatFactory(
3335
model_name: str,
34-
api_key: Optional[str] = None,
36+
api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
3537
streaming: bool = True,
3638
http_client: Optional[httpx.Client] = None,
37-
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
39+
obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
40+
base_url: Optional[str] = None,
41+
hosting_platform: Optional[str] = None,
3842
**kwargs: Any,
3943
) -> Union[ChatOpenAI, ChatBedrockConverse]:
4044
"""Return a chat model instance (ChatOpenAI or ChatBedrockConverse).
4145
"""
42-
inference_api_manager = (
43-
manage_workspaces().organizations.current.inference_apis
44-
)
45-
info = inference_api_manager.get(model_name=model_name)
46-
token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
47-
token = api_key if api_key is not None else token_env
46+
# Handle api_key and obo_token as callable functions
47+
if callable(api_key):
48+
api_key_getter = api_key
49+
else:
50+
def api_key_getter() -> Optional[str]:
51+
if api_key is None:
52+
return os.environ.get('SINGLESTOREDB_USER_TOKEN')
53+
return api_key
54+
55+
if callable(obo_token):
56+
obo_token_getter = obo_token
57+
else:
58+
def obo_token_getter() -> Optional[str]:
59+
return obo_token
60+
61+
# handle model info
62+
if base_url is None:
63+
base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
64+
if hosting_platform is None:
65+
hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM')
66+
if base_url is None or hosting_platform is None:
67+
inference_api_manager = (
68+
manage_workspaces().organizations.current.inference_apis
69+
)
70+
info = inference_api_manager.get(model_name=model_name)
71+
else:
72+
info = InferenceAPIInfo(
73+
service_id='',
74+
model_name=model_name,
75+
name='',
76+
connection_url=base_url,
77+
project_id='',
78+
hosting_platform=hosting_platform,
79+
)
80+
if base_url is not None:
81+
info.connection_url = base_url
82+
if hosting_platform is not None:
83+
info.hosting_platform = hosting_platform
4884

4985
if info.hosting_platform == 'Amazon':
5086
# Instantiate Bedrock client
@@ -86,12 +122,14 @@ def SingleStoreChatFactory(
86122

87123
def _inject_headers(request: Any, **_ignored: Any) -> None:
88124
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
125+
if api_key_getter is not None:
126+
token_val = api_key_getter()
127+
if token_val:
128+
request.headers['Authorization'] = f'Bearer {token_val}'
89129
if obo_token_getter is not None:
90130
obo_val = obo_token_getter()
91131
if obo_val:
92132
request.headers['X-S2-OBO'] = obo_val
93-
if token:
94-
request.headers['Authorization'] = f'Bearer {token}'
95133
request.headers.pop('X-Amz-Date', None)
96134
request.headers.pop('X-Amz-Security-Token', None)
97135

@@ -124,10 +162,29 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
124162
**kwargs,
125163
)
126164

165+
class OpenAIAuth(httpx.Auth):
166+
def auth_flow(
167+
self, request: httpx.Request,
168+
) -> Generator[httpx.Request, None, None]:
169+
if api_key_getter is not None:
170+
token_val = api_key_getter()
171+
if token_val:
172+
request.headers['Authorization'] = f'Bearer {token_val}'
173+
if obo_token_getter is not None:
174+
obo_val = obo_token_getter()
175+
if obo_val:
176+
request.headers['X-S2-OBO'] = obo_val
177+
yield request
178+
179+
http_client = httpx.Client(
180+
timeout=30,
181+
auth=OpenAIAuth(),
182+
)
183+
127184
# OpenAI / Azure OpenAI path
128185
openai_kwargs = dict(
129186
base_url=info.connection_url,
130-
api_key=token,
187+
api_key='placeholder',
131188
model=model_name,
132189
streaming=streaming,
133190
)

singlestoredb/ai/embeddings.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ def obo_token_getter() -> Optional[str]:
6262
base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL')
6363
if hosting_platform is None:
6464
hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM')
65-
6665
if base_url is None or hosting_platform is None:
6766
inference_api_manager = (
6867
manage_workspaces().organizations.current.inference_apis

0 commit comments

Comments
 (0)