Skip to content

Commit 054c8e2

Browse files
Introduce SingleStoreEmbeddingsFactory; small fixes.
1 parent 2fdd49b commit 054c8e2

File tree

2 files changed

+134
-35
lines changed

2 files changed

+134
-35
lines changed

singlestoredb/ai/chat.py

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,15 @@ def SingleStoreChatFactory(
8686

8787
if info.hosting_platform == 'Amazon':
8888
# Instantiate Bedrock client
89-
cfg = Config(
90-
signature_version=UNSIGNED,
91-
retries={
92-
'max_attempts': 1,
93-
'mode': 'standard',
94-
},
95-
)
89+
cfg_kwargs = {
90+
'signature_version': UNSIGNED,
91+
'retries': {'max_attempts': 1, 'mode': 'standard'},
92+
}
9693
if http_client is not None and http_client.timeout is not None:
97-
cfg.timeout = http_client.timeout
98-
cfg.connect_timeout = http_client.timeout
94+
cfg_kwargs['read_timeout'] = http_client.timeout
95+
cfg_kwargs['connect_timeout'] = http_client.timeout
96+
97+
cfg = Config(**cfg_kwargs)
9998
client = boto3.client(
10099
'bedrock-runtime',
101100
endpoint_url=info.connection_url, # redirect requests to UMG
@@ -104,36 +103,38 @@ def SingleStoreChatFactory(
104103
aws_secret_access_key='placeholder', # dummy value; UMG does not use this
105104
config=cfg,
106105
)
107-
if obo_token_getter is not None:
108-
def _inject_headers(request: Any, **_ignored: Any) -> None:
109-
"""Inject dynamic auth/OBO headers prior to Bedrock signing."""
106+
107+
def _inject_headers(request: Any, **_ignored: Any) -> None:
108+
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
109+
if obo_token_getter is not None:
110110
obo_val = obo_token_getter()
111111
if obo_val:
112112
request.headers['X-S2-OBO'] = obo_val
113-
if token:
114-
request.headers['Authorization'] = f'Bearer {token}'
115-
request.headers.pop('X-Amz-Date', None)
116-
request.headers.pop('X-Amz-Security-Token', None)
117-
118-
emitter = client._endpoint._event_emitter
119-
emitter.register_first(
120-
'before-send.bedrock-runtime.Converse',
121-
_inject_headers,
122-
)
123-
emitter.register_first(
124-
'before-send.bedrock-runtime.ConverseStream',
125-
_inject_headers,
126-
)
127-
emitter.register_first(
128-
'before-send.bedrock-runtime.InvokeModel',
129-
_inject_headers,
130-
)
131-
emitter.register_first(
132-
'before-send.bedrock-runtime.InvokeModelWithResponseStream',
133-
_inject_headers,
134-
)
113+
if token:
114+
request.headers['Authorization'] = f'Bearer {token}'
115+
request.headers.pop('X-Amz-Date', None)
116+
request.headers.pop('X-Amz-Security-Token', None)
117+
118+
emitter = client._endpoint._event_emitter
119+
emitter.register_first(
120+
'before-send.bedrock-runtime.Converse',
121+
_inject_headers,
122+
)
123+
emitter.register_first(
124+
'before-send.bedrock-runtime.ConverseStream',
125+
_inject_headers,
126+
)
127+
emitter.register_first(
128+
'before-send.bedrock-runtime.InvokeModel',
129+
_inject_headers,
130+
)
131+
emitter.register_first(
132+
'before-send.bedrock-runtime.InvokeModelWithResponseStream',
133+
_inject_headers,
134+
)
135+
135136
return ChatBedrockConverse(
136-
model=model_name,
137+
model_id=model_name,
137138
endpoint_url=info.connection_url, # redirect requests to UMG
138139
region_name='us-east-1', # dummy value; UMG does not use this
139140
aws_access_key_id='placeholder', # dummy value; UMG does not use this

singlestoredb/ai/embeddings.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
import os
22
from typing import Any
3+
from typing import Callable
4+
from typing import Optional
5+
from typing import Union
6+
7+
import httpx
38

49
from singlestoredb.fusion.handlers.utils import get_workspace_manager
510

@@ -11,6 +16,18 @@
1116
'Please install it with `pip install langchain_openai`.',
1217
)
1318

19+
try:
20+
from langchain_aws import BedrockEmbeddings
21+
except ImportError:
22+
raise ImportError(
23+
'Could not import langchain-aws python package. '
24+
'Please install it with `pip install langchain-aws`.',
25+
)
26+
27+
import boto3
28+
from botocore import UNSIGNED
29+
from botocore.config import Config
30+
1431

1532
class SingleStoreEmbeddings(OpenAIEmbeddings):
1633

@@ -25,3 +42,84 @@ def __init__(self, model_name: str, **kwargs: Any):
2542
model=model_name,
2643
**kwargs,
2744
)
45+
46+
47+
def SingleStoreEmbeddingsFactory(
48+
model_name: str,
49+
api_key: Optional[str] = None,
50+
http_client: Optional[httpx.Client] = None,
51+
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
52+
**kwargs: Any,
53+
) -> Union[OpenAIEmbeddings, BedrockEmbeddings]:
54+
"""Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings).
55+
"""
56+
inference_api_manager = (
57+
get_workspace_manager().organizations.current.inference_apis
58+
)
59+
info = inference_api_manager.get(model_name=model_name)
60+
token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
61+
token = api_key if api_key is not None else token_env
62+
63+
if info.hosting_platform == 'Amazon':
64+
# Instantiate Bedrock client
65+
cfg_kwargs = {
66+
'signature_version': UNSIGNED,
67+
'retries': {'max_attempts': 1, 'mode': 'standard'},
68+
}
69+
if http_client is not None and http_client.timeout is not None:
70+
cfg_kwargs['read_timeout'] = http_client.timeout
71+
cfg_kwargs['connect_timeout'] = http_client.timeout
72+
73+
cfg = Config(**cfg_kwargs)
74+
client = boto3.client(
75+
'bedrock-runtime',
76+
endpoint_url=info.connection_url, # redirect requests to UMG
77+
region_name='us-east-1', # dummy value; UMG does not use this
78+
aws_access_key_id='placeholder', # dummy value; UMG does not use this
79+
aws_secret_access_key='placeholder', # dummy value; UMG does not use this
80+
config=cfg,
81+
)
82+
83+
def _inject_headers(request: Any, **_ignored: Any) -> None:
84+
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
85+
if obo_token_getter is not None:
86+
obo_val = obo_token_getter()
87+
if obo_val:
88+
request.headers['X-S2-OBO'] = obo_val
89+
if token:
90+
request.headers['Authorization'] = f'Bearer {token}'
91+
request.headers.pop('X-Amz-Date', None)
92+
request.headers.pop('X-Amz-Security-Token', None)
93+
94+
emitter = client._endpoint._event_emitter
95+
emitter.register_first(
96+
'before-send.bedrock-runtime.InvokeModel',
97+
_inject_headers,
98+
)
99+
emitter.register_first(
100+
'before-send.bedrock-runtime.InvokeModelWithResponseStream',
101+
_inject_headers,
102+
)
103+
104+
return BedrockEmbeddings(
105+
model_id=model_name,
106+
endpoint_url=info.connection_url, # redirect requests to UMG
107+
region_name='us-east-1', # dummy value; UMG does not use this
108+
aws_access_key_id='placeholder', # dummy value; UMG does not use this
109+
aws_secret_access_key='placeholder', # dummy value; UMG does not use this
110+
client=client,
111+
**kwargs,
112+
)
113+
114+
# OpenAI / Azure OpenAI path
115+
openai_kwargs = dict(
116+
base_url=info.connection_url,
117+
api_key=token,
118+
model=model_name,
119+
)
120+
if http_client is not None:
121+
openai_kwargs['http_client'] = http_client
122+
return OpenAIEmbeddings(
123+
**openai_kwargs,
124+
**kwargs,
125+
)

0 commit comments

Comments
 (0)