import os
from typing import Any
+from typing import Callable
+from typing import Optional
+from typing import Union
+
+import httpx

from singlestoredb.fusion.handlers.utils import get_workspace_manager

@@ -11,6 +16,18 @@
        'Please install it with `pip install langchain_openai`.',
    )

+try:
+    from langchain_aws import BedrockEmbeddings
+except ImportError:
+    raise ImportError(
+        'Could not import langchain-aws python package. '
+        'Please install it with `pip install langchain-aws`.',
+    )
+
+import boto3
+from botocore import UNSIGNED
+from botocore.config import Config
+

class SingleStoreEmbeddings(OpenAIEmbeddings):

@@ -25,3 +42,84 @@ def __init__(self, model_name: str, **kwargs: Any):
            model=model_name,
            **kwargs,
        )
+
+
+def SingleStoreEmbeddingsFactory(
+    model_name: str,
+    api_key: Optional[str] = None,
+    http_client: Optional[httpx.Client] = None,
+    obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
+    **kwargs: Any,
+) -> Union[OpenAIEmbeddings, BedrockEmbeddings]:
+    """Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings).
+    """
+    inference_api_manager = (
+        get_workspace_manager().organizations.current.inference_apis
+    )
+    info = inference_api_manager.get(model_name=model_name)
+    token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
+    token = api_key if api_key is not None else token_env
+
+    if info.hosting_platform == 'Amazon':
+        # Instantiate the Bedrock runtime client
+        cfg_kwargs = {
+            'signature_version': UNSIGNED,
+            'retries': {'max_attempts': 1, 'mode': 'standard'},
+        }
+        if http_client is not None and http_client.timeout is not None:
+            # httpx exposes timeouts as an httpx.Timeout object; botocore
+            # expects numeric seconds, so pass the individual values through.
+            cfg_kwargs['read_timeout'] = http_client.timeout.read
+            cfg_kwargs['connect_timeout'] = http_client.timeout.connect
+
+        cfg = Config(**cfg_kwargs)
+        client = boto3.client(
+            'bedrock-runtime',
+            endpoint_url=info.connection_url,  # redirect requests to UMG
+            region_name='us-east-1',  # dummy value; UMG does not use this
+            aws_access_key_id='placeholder',  # dummy value; UMG does not use this
+            aws_secret_access_key='placeholder',  # dummy value; UMG does not use this
+            config=cfg,
+        )
+
+        def _inject_headers(request: Any, **_ignored: Any) -> None:
+            """Inject dynamic auth/OBO headers before the Bedrock request is sent."""
+            if obo_token_getter is not None:
+                obo_val = obo_token_getter()
+                if obo_val:
+                    request.headers['X-S2-OBO'] = obo_val
+            if token:
+                request.headers['Authorization'] = f'Bearer {token}'
+            # Remove any AWS signing headers; auth is handled via the bearer token.
+            request.headers.pop('X-Amz-Date', None)
+            request.headers.pop('X-Amz-Security-Token', None)
+
+        # Register the hook on the botocore event emitter so it runs for both
+        # the standard and the streaming InvokeModel operations.
+        emitter = client._endpoint._event_emitter
+        emitter.register_first(
+            'before-send.bedrock-runtime.InvokeModel',
+            _inject_headers,
+        )
+        emitter.register_first(
+            'before-send.bedrock-runtime.InvokeModelWithResponseStream',
+            _inject_headers,
+        )
+
+        return BedrockEmbeddings(
+            model_id=model_name,
+            endpoint_url=info.connection_url,  # redirect requests to UMG
+            region_name='us-east-1',  # dummy value; UMG does not use this
+            aws_access_key_id='placeholder',  # dummy value; UMG does not use this
+            aws_secret_access_key='placeholder',  # dummy value; UMG does not use this
+            client=client,
+            **kwargs,
+        )
+
+    # OpenAI / Azure OpenAI path
+    openai_kwargs = dict(
+        base_url=info.connection_url,
+        api_key=token,
+        model=model_name,
+    )
+    if http_client is not None:
+        openai_kwargs['http_client'] = http_client
+    return OpenAIEmbeddings(
+        **openai_kwargs,
+        **kwargs,
+    )
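
A minimal usage sketch of the factory, assuming it is in scope (imported from this module); the model name, timeout, and OBO-token source below are illustrative placeholders, not values taken from the diff:

import os

import httpx

embeddings = SingleStoreEmbeddingsFactory(
    model_name='text-embedding-3-small',                      # illustrative model name
    api_key=os.environ.get('SINGLESTOREDB_USER_TOKEN'),       # same env var the factory falls back to
    http_client=httpx.Client(timeout=30.0),                   # optional; its timeout is reused for Bedrock
    obo_token_getter=lambda: os.environ.get('S2_OBO_TOKEN'),  # illustrative OBO token source
)
vectors = embeddings.embed_documents(['hello world'])
query_vec = embeddings.embed_query('hello')

Both return types implement the LangChain Embeddings interface, so embed_documents and embed_query work the same way regardless of which hosting platform the model resolves to.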