Skip to content

Commit 6ea4a6d

Browse files
feat: Introduce SingleStoreChat wrapper that can use either the OpenAI or Amazon Bedrock Converse protocol interchangeably.
1 parent 8c869cf commit 6ea4a6d

File tree

1 file changed

+297
-5
lines changed

1 file changed

+297
-5
lines changed

singlestoredb/ai/chat.py

Lines changed: 297 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
import os
2+
import uuid
23
from typing import Any
4+
from typing import AsyncIterator
5+
from typing import Optional
6+
7+
import httpx
38

49
from singlestoredb.fusion.handlers.utils import get_workspace_manager
510

@@ -8,33 +13,320 @@
813
except ImportError:
914
raise ImportError(
1015
'Could not import langchain_openai python package. '
11-
'Please install it with `pip install langchain_openai`.',
16+
'Please install it with `pip install langchain-openai`.',
17+
)
18+
19+
# Guard the optional langchain-aws dependency; ChatBedrockConverse is the
# Bedrock Converse backend used by the wrappers below.
try:
    from langchain_aws import ChatBedrockConverse
except ImportError as exc:
    # Chain the original ImportError so the real root cause (e.g. a broken
    # transitive dependency rather than a missing package) stays visible.
    raise ImportError(
        'Could not import langchain-aws python package. '
        'Please install it with `pip install langchain-aws`.',
    ) from exc
1326

1427

1528
class SingleStoreChatOpenAI(ChatOpenAI):
    """``ChatOpenAI`` preconfigured for the SingleStore inference gateway.

    The model's connection URL is resolved through the Fusion workspace
    manager. Authentication uses the explicit ``api_key`` when given,
    otherwise the ``SINGLESTOREDB_USER_TOKEN`` environment variable.
    """

    def __init__(self, model_name: str, api_key: Optional[str] = None, **kwargs: Any):
        # Look up the gateway endpoint for this model via Fusion.
        manager = get_workspace_manager().organizations.current.inference_apis
        endpoint_info = manager.get(model_name=model_name)
        if api_key is None:
            api_key = os.environ.get('SINGLESTOREDB_USER_TOKEN')
        super().__init__(
            base_url=endpoint_info.connection_url,
            api_key=api_key,
            model=model_name,
            **kwargs,
        )
2745

2846

2947
class SingleStoreChat(ChatOpenAI):
    """OpenAI-protocol chat model routed through the SingleStore gateway.

    Resolves ``model_name`` to its gateway ``connection_url`` via the
    Fusion workspace manager; the bearer token is ``api_key`` when
    provided, else the ``SINGLESTOREDB_USER_TOKEN`` environment variable.
    """

    def __init__(self, model_name: str, api_key: Optional[str] = None, **kwargs: Any):
        apis = get_workspace_manager().organizations.current.inference_apis
        model_info = apis.get(model_name=model_name)
        # Prefer the caller-supplied key; fall back to the env token.
        token = os.environ.get('SINGLESTOREDB_USER_TOKEN') if api_key is None else api_key
        super().__init__(
            base_url=model_info.connection_url,
            api_key=token,
            model=model_name,
            **kwargs,
        )
64+
65+
66+
class SingleStoreExperimentalChat:
    """Experimental unified chat interface (prefix-based two-part identifier).

    The model name SHOULD (for dynamic backend selection) be of the form::

        <prefix><sep><actual_model>

    where ``<sep>`` is one of ``.``, ``:`` or ``/`` and ``<prefix>`` is:

    * ``aura``     -> OpenAI style (ChatOpenAI backend)
    * ``aura-azr`` -> Azure OpenAI style (still ChatOpenAI backend)
    * ``aura-amz`` -> Amazon Bedrock (ChatBedrockConverse backend)

    If no delimiter yields a recognized prefix, the entire string is
    treated as an OpenAI-style model (ChatOpenAI).

    Only the prefix ``aura-amz`` triggers Bedrock usage; in that case the
    *second* component (after the delimiter) is passed as the model name
    to the Bedrock client. For other prefixes the second component is
    passed to ChatOpenAI with the SingleStore Fusion-provided base_url.

    This class uses composition and delegates attribute access to the
    chosen backend client for near drop-in behavior.
    """

    _VALID_PREFIXES = {'aura', 'aura-azr', 'aura-amz'}

    def __init__(
        self,
        model_name: str,
        http_client: Optional[httpx.Client] = None,
        api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        prefix, actual_model = self._parse_identifier(model_name)

        inference_api_manager = (
            get_workspace_manager().organizations.current.inference_apis
        )
        # Fusion lookup uses the parsed model portion (prefix stripped);
        # the prefix only selects the client backend below.
        info = inference_api_manager.get(model_name=actual_model)

        if prefix == 'aura-amz':
            backend_type = 'bedrock'
        elif prefix == 'aura-azr':
            backend_type = 'azure-openai'
        else:
            backend_type = 'openai'

        # Extract headers from provided http_client (if any) for possible reuse.
        provided_headers: dict[str, str] = {}
        if http_client is not None and hasattr(http_client, 'headers'):
            try:
                provided_headers = dict(http_client.headers)  # make a copy
            except Exception:
                provided_headers = {}

        # Fallback header store for invoke/stream-time injection; only
        # populated on the Bedrock path when the client exposes no
        # mutable ``default_headers`` mapping.
        self._auth_header: Optional[dict[str, str]] = None

        token = (
            api_key
            if api_key is not None
            else os.environ.get('SINGLESTOREDB_USER_TOKEN')
        )

        if backend_type == 'bedrock':
            # Strip AWS credentials from the environment so boto3 cannot
            # pick them up: auth goes through the SingleStore gateway token.
            # Saved for potential restoration/inspection by callers.
            self._removed_aws_env: dict[str, str] = {}
            for _v in (
                'AWS_ACCESS_KEY_ID',
                'AWS_SECRET_ACCESS_KEY',
                'AWS_SESSION_TOKEN',
            ):
                if _v in os.environ:
                    self._removed_aws_env[_v] = os.environ.pop(_v)

            # Generate a per-instance client ID for tracing Bedrock calls.
            self._client_id = str(uuid.uuid4())
            self._client = ChatBedrockConverse(
                base_url=info.connection_url,
                model=actual_model,
                **kwargs,
            )

            # Attempt to inject Authorization header for downstream HTTP
            # layers. Not all implementations expose a direct header map;
            # stash the headers for per-call injection if needed.
            merged_headers: dict[str, str] = dict(provided_headers)
            if token:
                merged_headers.setdefault('Authorization', f'Bearer {token}')
            # Always include X-ClientID for the Bedrock path.
            merged_headers.setdefault('X-ClientID', self._client_id)
            if merged_headers:
                default_headers = getattr(self._client, 'default_headers', None)
                if isinstance(default_headers, dict):
                    # Add only keys the backend hasn't already set itself.
                    for key, value in merged_headers.items():
                        default_headers.setdefault(key, value)
                else:
                    self._auth_header = merged_headers  # fallback for invoke/stream
        else:
            openai_kwargs = dict(
                base_url=info.connection_url,
                api_key=token,
                model=actual_model,
            )
            if http_client is not None:
                # Some versions accept 'http_client' parameter for custom transport.
                openai_kwargs['http_client'] = http_client
            self._client = ChatOpenAI(
                **openai_kwargs,
                **kwargs,
            )

        self._backend_type = backend_type
        self.model_name = model_name      # external identifier provided by caller
        self.actual_model = actual_model  # model portion after prefix
        self.prefix = prefix              # normalized prefix
        self.connection_url = info.connection_url

    @classmethod
    def _parse_identifier(cls, identifier: str) -> tuple[str, str]:
        """Split ``identifier`` into ``(prefix, model)``.

        Each delimiter is tried in order and only a split that produces a
        recognized prefix wins. Previously the first delimiter *present*
        short-circuited parsing, so an id like ``aura-amz:anthropic.claude-3``
        (Bedrock model names contain dots) split on ``.`` instead of ``:``
        and silently fell back to the OpenAI path.
        """
        for sep in ('.', ':', '/'):
            if sep in identifier:
                head, tail = identifier.split(sep, 1)
                prefix = head.strip().lower()
                if prefix in cls._VALID_PREFIXES:
                    return prefix, tail.strip()
        return 'aura', identifier.strip()  # treat whole string as model

    # ---------------------------------------------------------------------
    # Delegation layer
    # ---------------------------------------------------------------------
    def __getattr__(self, item: str) -> Any:
        # Guard: if ``_client`` was never assigned (exception escaped
        # ``__init__``, or during unpickling) the lookup below would
        # recurse into __getattr__ forever.
        if item == '_client':
            raise AttributeError(item)
        return getattr(self._client, item)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _maybe_inject_headers(self, kwargs: dict[str, Any]) -> None:
        """Inject Bedrock auth headers into kwargs if we only have a fallback.

        If the Bedrock client accepted headers via its own internal
        ``default_headers`` we don't need to do anything here. When we had
        to stash headers into ``_auth_header`` we add them for each outbound
        call that allows a ``headers`` kwarg and has not already provided
        its own.
        """
        if (
            self._backend_type == 'bedrock'
            and getattr(self, '_auth_header', None)
            and 'headers' not in kwargs
        ):
            kwargs['headers'] = self._auth_header

    def as_base(self) -> Any:
        """Return the underlying backend client instance.

        This gives callers direct access to provider specific methods or
        configuration that aren't surfaced by the experimental wrapper.
        """
        return self._client

    def invoke(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return self._client.invoke(*args, **kwargs)

    async def ainvoke(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return await self._client.ainvoke(*args, **kwargs)

    def stream(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return self._client.stream(*args, **kwargs)

    async def astream(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncIterator[Any]:
        self._maybe_inject_headers(kwargs)
        async for chunk in self._client.astream(*args, **kwargs):
            yield chunk

    # ------------------------------------------------------------------
    # Extended delegation for additional common chat model surface area.
    # Each method simply injects headers (if needed) then forwards.
    # ------------------------------------------------------------------
    def generate(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return self._client.generate(*args, **kwargs)

    async def agenerate(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return await self._client.agenerate(*args, **kwargs)

    def predict(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return self._client.predict(*args, **kwargs)

    async def apredict(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        self._maybe_inject_headers(kwargs)
        return await self._client.apredict(*args, **kwargs)

    def predict_messages(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        self._maybe_inject_headers(kwargs)
        return self._client.predict_messages(*args, **kwargs)

    async def apredict_messages(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        self._maybe_inject_headers(kwargs)
        return await self._client.apredict_messages(*args, **kwargs)

    def batch(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return self._client.batch(*args, **kwargs)

    async def abatch(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return await self._client.abatch(*args, **kwargs)

    def apply(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return self._client.apply(*args, **kwargs)

    async def aapply(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        self._maybe_inject_headers(kwargs)
        return await self._client.aapply(*args, **kwargs)

    def __repr__(self) -> str:
        return (
            'SingleStoreExperimentalChat('
            f'identifier={self.model_name!r}, '
            f'actual_model={self.actual_model!r}, '
            f'prefix={self.prefix}, backend={self._backend_type})'
        )

0 commit comments

Comments
 (0)