
Commit 40e12c2

Introduce new factory methods for testing purposes.
1 parent 3c7a979 commit 40e12c2

File tree

2 files changed: +209, -0 lines changed

singlestoredb/ai/chat.py

Lines changed: 110 additions & 0 deletions
@@ -175,3 +175,113 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
        **openai_kwargs,
        **kwargs,
    )


def NewSingleStoreChatFactory(
    model_name: str,
    api_key: Optional[str] = None,
    streaming: bool = True,
    http_client: Optional[httpx.Client] = None,
    obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
    **kwargs: Any,
) -> Union[ChatOpenAI, ChatBedrockConverse]:
    """Return a chat model instance (ChatOpenAI or ChatBedrockConverse).
    """
    inference_api_manager = (
        manage_workspaces().organizations.current.inference_apis
    )
    info = inference_api_manager.get(model_name=model_name)
    token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
    token = api_key if api_key is not None else token_env

    if info.hosting_platform == 'Amazon':
        # Instantiate Bedrock client
        cfg_kwargs = {
            'signature_version': UNSIGNED,
            'retries': {'max_attempts': 1, 'mode': 'standard'},
        }
        # Extract timeouts from http_client if provided
        t = http_client.timeout if http_client is not None else None
        connect_timeout = None
        read_timeout = None
        if t is not None:
            if isinstance(t, httpx.Timeout):
                if t.connect is not None:
                    connect_timeout = float(t.connect)
                if t.read is not None:
                    read_timeout = float(t.read)
                if connect_timeout is None and read_timeout is not None:
                    connect_timeout = read_timeout
                if read_timeout is None and connect_timeout is not None:
                    read_timeout = connect_timeout
            elif isinstance(t, (int, float)):
                connect_timeout = float(t)
                read_timeout = float(t)
        if read_timeout is not None:
            cfg_kwargs['read_timeout'] = read_timeout
        if connect_timeout is not None:
            cfg_kwargs['connect_timeout'] = connect_timeout

        cfg = Config(**cfg_kwargs)
        client = boto3.client(
            'bedrock-runtime',
            endpoint_url=info.connection_url,
            region_name='us-east-1',
            aws_access_key_id='placeholder',
            aws_secret_access_key='placeholder',
            config=cfg,
        )

        def _inject_headers(request: Any, **_ignored: Any) -> None:
            """Inject dynamic auth/OBO headers prior to Bedrock sending."""
            if obo_token_getter is not None:
                obo_val = obo_token_getter()
                if obo_val:
                    request.headers['X-S2-OBO'] = obo_val
            if token:
                request.headers['Authorization'] = f'Bearer {token}'
            request.headers.pop('X-Amz-Date', None)
            request.headers.pop('X-Amz-Security-Token', None)

        emitter = client._endpoint._event_emitter
        emitter.register_first(
            'before-send.bedrock-runtime.Converse',
            _inject_headers,
        )
        emitter.register_first(
            'before-send.bedrock-runtime.ConverseStream',
            _inject_headers,
        )
        emitter.register_first(
            'before-send.bedrock-runtime.InvokeModel',
            _inject_headers,
        )
        emitter.register_first(
            'before-send.bedrock-runtime.InvokeModelWithResponseStream',
            _inject_headers,
        )

        return ChatBedrockConverse(
            model_id=model_name,
            endpoint_url=info.connection_url,
            region_name='us-east-1',
            aws_access_key_id='placeholder',
            aws_secret_access_key='placeholder',
            disable_streaming=not streaming,
            client=client,
            **kwargs,
        )

    # OpenAI / Azure OpenAI path
    openai_kwargs = dict(
        base_url=info.connection_url,
        api_key=token,
        model=model_name,
        streaming=streaming,
    )
    if http_client is not None:
        openai_kwargs['http_client'] = http_client
    return ChatOpenAI(
        **openai_kwargs,
        **kwargs,
    )
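
Below is a minimal usage sketch, not part of this commit; the model name, token values, and timeout settings are placeholder assumptions. The factory resolves the actual endpoint and hosting platform from the organization's inference API metadata.

# Hypothetical example -- model name and tokens are placeholders.
import httpx

from singlestoredb.ai.chat import NewSingleStoreChatFactory

chat = NewSingleStoreChatFactory(
    model_name='gpt-4o-mini',      # placeholder; use a model exposed by your org
    api_key='my-test-token',       # overrides SINGLESTOREDB_USER_TOKEN if set
    streaming=False,
    # connect/read timeouts are copied into the Bedrock Config when the model
    # is hosted on Amazon; otherwise the client is passed through to ChatOpenAI
    http_client=httpx.Client(timeout=httpx.Timeout(30.0, connect=5.0)),
    obo_token_getter=lambda: 'obo-token',  # injected as X-S2-OBO on Bedrock calls
)
# The returned object is a ChatOpenAI or ChatBedrockConverse instance and
# supports the usual LangChain chat interface, e.g. chat.invoke('Hello!')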

singlestoredb/ai/embeddings.py

Lines changed: 99 additions & 0 deletions
@@ -141,3 +141,102 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
        **openai_kwargs,
        **kwargs,
    )


def NewSingleStoreEmbeddingsFactory(
    model_name: str,
    api_key: Optional[str] = None,
    http_client: Optional[httpx.Client] = None,
    obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
    **kwargs: Any,
) -> Union[OpenAIEmbeddings, BedrockEmbeddings]:
    """Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings).
    """
    inference_api_manager = (
        manage_workspaces().organizations.current.inference_apis
    )
    info = inference_api_manager.get(model_name=model_name)
    token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
    token = api_key if api_key is not None else token_env

    if info.hosting_platform == 'Amazon':
        # Instantiate Bedrock client
        cfg_kwargs = {
            'signature_version': UNSIGNED,
            'retries': {'max_attempts': 1, 'mode': 'standard'},
        }
        # Extract timeouts from http_client if provided
        t = http_client.timeout if http_client is not None else None
        connect_timeout = None
        read_timeout = None
        if t is not None:
            if isinstance(t, httpx.Timeout):
                if t.connect is not None:
                    connect_timeout = float(t.connect)
                if t.read is not None:
                    read_timeout = float(t.read)
                if connect_timeout is None and read_timeout is not None:
                    connect_timeout = read_timeout
                if read_timeout is None and connect_timeout is not None:
                    read_timeout = connect_timeout
            elif isinstance(t, (int, float)):
                connect_timeout = float(t)
                read_timeout = float(t)
        if read_timeout is not None:
            cfg_kwargs['read_timeout'] = read_timeout
        if connect_timeout is not None:
            cfg_kwargs['connect_timeout'] = connect_timeout

        cfg = Config(**cfg_kwargs)
        client = boto3.client(
            'bedrock-runtime',
            endpoint_url=info.connection_url,
            region_name='us-east-1',
            aws_access_key_id='placeholder',
            aws_secret_access_key='placeholder',
            config=cfg,
        )

        def _inject_headers(request: Any, **_ignored: Any) -> None:
            """Inject dynamic auth/OBO headers prior to Bedrock sending."""
            if obo_token_getter is not None:
                obo_val = obo_token_getter()
                if obo_val:
                    request.headers['X-S2-OBO'] = obo_val
            if token:
                request.headers['Authorization'] = f'Bearer {token}'
            request.headers.pop('X-Amz-Date', None)
            request.headers.pop('X-Amz-Security-Token', None)

        emitter = client._endpoint._event_emitter
        emitter.register_first(
            'before-send.bedrock-runtime.InvokeModel',
            _inject_headers,
        )
        emitter.register_first(
            'before-send.bedrock-runtime.InvokeModelWithResponseStream',
            _inject_headers,
        )

        return BedrockEmbeddings(
            model_id=model_name,
            endpoint_url=info.connection_url,
            region_name='us-east-1',
            aws_access_key_id='placeholder',
            aws_secret_access_key='placeholder',
            client=client,
            **kwargs,
        )

    # OpenAI / Azure OpenAI path
    openai_kwargs = dict(
        base_url=info.connection_url,
        api_key=token,
        model=model_name,
    )
    if http_client is not None:
        openai_kwargs['http_client'] = http_client
    return OpenAIEmbeddings(
        **openai_kwargs,
        **kwargs,
    )
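
A similar hypothetical usage sketch for the embeddings factory, not part of this commit; the model name is a placeholder assumption:

import httpx

from singlestoredb.ai.embeddings import NewSingleStoreEmbeddingsFactory

embeddings = NewSingleStoreEmbeddingsFactory(
    model_name='text-embedding-3-small',  # placeholder; use a model exposed by your org
    http_client=httpx.Client(timeout=30.0),  # a plain float timeout also propagates to the Bedrock Config
)
# The returned object is an OpenAIEmbeddings or BedrockEmbeddings instance, e.g.:
# vectors = embeddings.embed_documents(['hello', 'world'])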
