11import os
2+ from collections .abc import Generator
23from typing import Any
34from typing import Callable
45from typing import Optional
78import httpx
89
910from singlestoredb import manage_workspaces
11+ from singlestoredb .management .inference_api import InferenceAPIInfo
1012
1113try :
1214 from langchain_openai import OpenAIEmbeddings
3133
3234def SingleStoreEmbeddingsFactory (
3335 model_name : str ,
34- api_key : Optional [str ] = None ,
36+ api_key : Optional [Union [ Optional [ str ], Callable [[], Optional [ str ]]] ] = None ,
3537 http_client : Optional [httpx .Client ] = None ,
36- obo_token_getter : Optional [Callable [[], Optional [str ]]] = None ,
38+ obo_token : Optional [Union [Optional [str ], Callable [[], Optional [str ]]]] = None ,
39+ base_url : Optional [str ] = None ,
40+ hosting_platform : Optional [str ] = None ,
3741 ** kwargs : Any ,
3842) -> Union [OpenAIEmbeddings , BedrockEmbeddings ]:
3943 """Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings).
4044 """
41- inference_api_manager = (
42- manage_workspaces ().organizations .current .inference_apis
43- )
44- info = inference_api_manager .get (model_name = model_name )
45- token_env = os .environ .get ('SINGLESTOREDB_USER_TOKEN' )
46- token = api_key if api_key is not None else token_env
45+ # Handle api_key and obo_token as callable functions
46+ if callable (api_key ):
47+ api_key_getter = api_key
48+ else :
49+ def api_key_getter () -> Optional [str ]:
50+ if api_key is None :
51+ return os .environ .get ('SINGLESTOREDB_USER_TOKEN' )
52+ return api_key
53+
54+ if callable (obo_token ):
55+ obo_token_getter = obo_token
56+ else :
57+ def obo_token_getter () -> Optional [str ]:
58+ return obo_token
59+
60+ # handle model info
61+ if base_url is None :
62+ base_url = os .environ .get ('SINGLESTOREDB_INFERENCE_API_BASE_URL' )
63+ if hosting_platform is None :
64+ hosting_platform = os .environ .get ('SINGLESTOREDB_INFERENCE_API_HOSTING_PLATFORM' )
65+
66+ if base_url is None or hosting_platform is None :
67+ inference_api_manager = (
68+ manage_workspaces ().organizations .current .inference_apis
69+ )
70+ info = inference_api_manager .get (model_name = model_name )
71+ else :
72+ info = InferenceAPIInfo (
73+ service_id = '' ,
74+ model_name = model_name ,
75+ name = '' ,
76+ connection_url = base_url ,
77+ project_id = '' ,
78+ hosting_platform = hosting_platform ,
79+ )
80+ if base_url is not None :
81+ info .connection_url = base_url
82+ if hosting_platform is not None :
83+ info .hosting_platform = hosting_platform
4784
4885 if info .hosting_platform == 'Amazon' :
4986 # Instantiate Bedrock client
@@ -85,12 +122,14 @@ def SingleStoreEmbeddingsFactory(
85122
86123 def _inject_headers (request : Any , ** _ignored : Any ) -> None :
87124 """Inject dynamic auth/OBO headers prior to Bedrock sending."""
125+ if api_key_getter is not None :
126+ token_val = api_key_getter ()
127+ if token_val :
128+ request .headers ['Authorization' ] = f'Bearer { token_val } '
88129 if obo_token_getter is not None :
89130 obo_val = obo_token_getter ()
90131 if obo_val :
91132 request .headers ['X-S2-OBO' ] = obo_val
92- if token :
93- request .headers ['Authorization' ] = f'Bearer { token } '
94133 request .headers .pop ('X-Amz-Date' , None )
95134 request .headers .pop ('X-Amz-Security-Token' , None )
96135
@@ -114,10 +153,29 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
114153 ** kwargs ,
115154 )
116155
156+ class OpenAIAuth (httpx .Auth ):
157+ def auth_flow (
158+ self , request : httpx .Request ,
159+ ) -> Generator [httpx .Request , None , None ]:
160+ if api_key_getter is not None :
161+ token_val = api_key_getter ()
162+ if token_val :
163+ request .headers ['Authorization' ] = f'Bearer { token_val } '
164+ if obo_token_getter is not None :
165+ obo_val = obo_token_getter ()
166+ if obo_val :
167+ request .headers ['X-S2-OBO' ] = obo_val
168+ yield request
169+
170+ http_client = httpx .Client (
171+ timeout = 30 ,
172+ auth = OpenAIAuth (),
173+ )
174+
117175 # OpenAI / Azure OpenAI path
118176 openai_kwargs = dict (
119177 base_url = info .connection_url ,
120- api_key = token ,
178+ api_key = 'placeholder' ,
121179 model = model_name ,
122180 )
123181 if http_client is not None :
0 commit comments