|
1 | 1 | import os |
| 2 | +from collections.abc import Generator |
2 | 3 | from typing import Any |
3 | 4 | from typing import Callable |
4 | 5 | from typing import Optional |
|
7 | 8 | import httpx |
8 | 9 |
|
9 | 10 | from singlestoredb import manage_workspaces |
| 11 | +from singlestoredb.management.inference_api import InferenceAPIInfo |
10 | 12 |
|
11 | 13 | try: |
12 | 14 | from langchain_openai import ChatOpenAI |
|
31 | 33 |
|
32 | 34 | def SingleStoreChatFactory( |
33 | 35 | model_name: str, |
34 | | - api_key: Optional[str] = None, |
| 36 | + api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None, |
35 | 37 | streaming: bool = True, |
36 | 38 | http_client: Optional[httpx.Client] = None, |
37 | | - obo_token_getter: Optional[Callable[[], Optional[str]]] = None, |
| 39 | + obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None, |
| 40 | + base_url: Optional[str] = None, |
| 41 | + hosting_platform: Optional[str] = None, |
38 | 42 | **kwargs: Any, |
39 | 43 | ) -> Union[ChatOpenAI, ChatBedrockConverse]: |
40 | 44 | """Return a chat model instance (ChatOpenAI or ChatBedrockConverse). |
41 | 45 | """ |
42 | | - inference_api_manager = ( |
43 | | - manage_workspaces().organizations.current.inference_apis |
44 | | - ) |
45 | | - info = inference_api_manager.get(model_name=model_name) |
46 | | - token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN') |
47 | | - token = api_key if api_key is not None else token_env |
| 46 | + # Handle api_key and obo_token as callable functions |
| 47 | + if callable(api_key): |
| 48 | + api_key_getter = api_key |
| 49 | + else: |
| 50 | + def api_key_getter() -> Optional[str]: |
| 51 | + if api_key is None: |
| 52 | + return os.environ.get('SINGLESTOREDB_USER_TOKEN') |
| 53 | + return api_key |
| 54 | + |
| 55 | + if callable(obo_token): |
| 56 | + obo_token_getter = obo_token |
| 57 | + else: |
| 58 | + def obo_token_getter() -> Optional[str]: |
| 59 | + return obo_token |
| 60 | + |
| 61 | + # handle model info |
| 62 | + if base_url is None: |
| 63 | + base_url = os.environ.get('SINGLESTOREDB_INFERENCE_API_BASE_URL') |
| 64 | + if hosting_platform is None: |
| 65 | + hosting_platform = os.environ.get('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM') |
| 66 | + if base_url is None or hosting_platform is None: |
| 67 | + inference_api_manager = ( |
| 68 | + manage_workspaces().organizations.current.inference_apis |
| 69 | + ) |
| 70 | + info = inference_api_manager.get(model_name=model_name) |
| 71 | + else: |
| 72 | + info = InferenceAPIInfo( |
| 73 | + service_id='', |
| 74 | + model_name=model_name, |
| 75 | + name='', |
| 76 | + connection_url=base_url, |
| 77 | + project_id='', |
| 78 | + hosting_platform=hosting_platform, |
| 79 | + ) |
| 80 | + if base_url is not None: |
| 81 | + info.connection_url = base_url |
| 82 | + if hosting_platform is not None: |
| 83 | + info.hosting_platform = hosting_platform |
48 | 84 |
|
49 | 85 | if info.hosting_platform == 'Amazon': |
50 | 86 | # Instantiate Bedrock client |
@@ -86,12 +122,14 @@ def SingleStoreChatFactory( |
86 | 122 |
|
87 | 123 | def _inject_headers(request: Any, **_ignored: Any) -> None: |
88 | 124 | """Inject dynamic auth/OBO headers prior to Bedrock sending.""" |
| 125 | + if api_key_getter is not None: |
| 126 | + token_val = api_key_getter() |
| 127 | + if token_val: |
| 128 | + request.headers['Authorization'] = f'Bearer {token_val}' |
89 | 129 | if obo_token_getter is not None: |
90 | 130 | obo_val = obo_token_getter() |
91 | 131 | if obo_val: |
92 | 132 | request.headers['X-S2-OBO'] = obo_val |
93 | | - if token: |
94 | | - request.headers['Authorization'] = f'Bearer {token}' |
95 | 133 | request.headers.pop('X-Amz-Date', None) |
96 | 134 | request.headers.pop('X-Amz-Security-Token', None) |
97 | 135 |
|
@@ -124,10 +162,29 @@ def _inject_headers(request: Any, **_ignored: Any) -> None: |
124 | 162 | **kwargs, |
125 | 163 | ) |
126 | 164 |
|
| 165 | + class OpenAIAuth(httpx.Auth): |
| 166 | + def auth_flow( |
| 167 | + self, request: httpx.Request, |
| 168 | + ) -> Generator[httpx.Request, None, None]: |
| 169 | + if api_key_getter is not None: |
| 170 | + token_val = api_key_getter() |
| 171 | + if token_val: |
| 172 | + request.headers['Authorization'] = f'Bearer {token_val}' |
| 173 | + if obo_token_getter is not None: |
| 174 | + obo_val = obo_token_getter() |
| 175 | + if obo_val: |
| 176 | + request.headers['X-S2-OBO'] = obo_val |
| 177 | + yield request |
| 178 | + |
| 179 | + http_client = httpx.Client( |
| 180 | + timeout=30, |
| 181 | + auth=OpenAIAuth(), |
| 182 | + ) |
| 183 | + |
127 | 184 | # OpenAI / Azure OpenAI path |
128 | 185 | openai_kwargs = dict( |
129 | 186 | base_url=info.connection_url, |
130 | | - api_key=token, |
| 187 | + api_key='placeholder', |
131 | 188 | model=model_name, |
132 | 189 | streaming=streaming, |
133 | 190 | ) |
|
0 commit comments