import os
import uuid
from typing import Any
from typing import AsyncIterator
from typing import Optional

import httpx

from singlestoredb.fusion.handlers.utils import get_workspace_manager

try:
    from langchain_openai import ChatOpenAI
except ImportError:
    raise ImportError(
        'Could not import langchain_openai python package. '
        'Please install it with `pip install langchain-openai`.',
    )

try:
    from langchain_aws import ChatBedrockConverse
except ImportError:
    raise ImportError(
        'Could not import langchain-aws python package. '
        'Please install it with `pip install langchain-aws`.',
    )


class SingleStoreChatOpenAI(ChatOpenAI):
    def __init__(self, model_name: str, api_key: Optional[str] = None, **kwargs: Any):
        inference_api_manager = (
            get_workspace_manager().organizations.current.inference_apis
        )
        info = inference_api_manager.get(model_name=model_name)
        token = (
            api_key
            if api_key is not None
            else os.environ.get('SINGLESTOREDB_USER_TOKEN')
        )
        super().__init__(
            base_url=info.connection_url,
            api_key=token,
            model=model_name,
            **kwargs,
        )


class SingleStoreChat(ChatOpenAI):
    def __init__(self, model_name: str, api_key: Optional[str] = None, **kwargs: Any):
        inference_api_manager = (
            get_workspace_manager().organizations.current.inference_apis
        )
        info = inference_api_manager.get(model_name=model_name)
        token = (
            api_key
            if api_key is not None
            else os.environ.get('SINGLESTOREDB_USER_TOKEN')
        )
        super().__init__(
            base_url=info.connection_url,
            api_key=token,
            model=model_name,
            **kwargs,
        )
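

# Illustrative usage of the wrappers above (a sketch; assumes a reachable
# SingleStore workspace and either an explicit ``api_key`` or the
# SINGLESTOREDB_USER_TOKEN environment variable; the model name is
# hypothetical):
#
#     chat = SingleStoreChatOpenAI('gpt-4o-mini')
#     print(chat.invoke('Say hello in one word.').content)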


class SingleStoreExperimentalChat:
    """Experimental unified chat interface (prefix-based two-part identifier).

    For dynamic backend selection, the input model name must be of the form
    ``<prefix>.<actual_model>``, where ``<prefix>`` is one of:

    * ``aura``     -> OpenAI style (ChatOpenAI backend)
    * ``aura-azr`` -> Azure OpenAI style (still ChatOpenAI backend)
    * ``aura-amz`` -> Amazon Bedrock (ChatBedrockConverse backend)

    If no delimiter (".", ":", or "/") is present, or the prefix is
    unrecognized, the entire string is treated as an OpenAI-style model
    (ChatOpenAI).

    Only the prefix ``aura-amz`` triggers Bedrock usage; in that case the
    *second* component (after the first delimiter) is passed as the model
    name to the Bedrock client. For other prefixes the second component is
    passed to ChatOpenAI with the SingleStore Fusion-provided base_url.

    This class uses composition and delegates attribute access to the chosen
    backend client for near drop-in behavior.
    """

    _VALID_PREFIXES = {'aura', 'aura-azr', 'aura-amz'}

    def __init__(
        self,
        model_name: str,
        http_client: Optional[httpx.Client] = None,
        api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        prefix, actual_model = self._parse_identifier(model_name)

        inference_api_manager = (
            get_workspace_manager().organizations.current.inference_apis
        )
        # Look up the parsed model portion in Fusion; the gateway keeps the
        # provider mapping logic server-side.
        info = inference_api_manager.get(model_name=actual_model)
        if prefix == 'aura-amz':
            backend_type = 'bedrock'
        elif prefix == 'aura-azr':
            backend_type = 'azure-openai'
        else:
            backend_type = 'openai'

        # Extract headers from the provided http_client (if any) for reuse.
        provided_headers: dict[str, str] = {}
        if http_client is not None and hasattr(http_client, 'headers'):
            try:
                provided_headers = dict(http_client.headers)  # make a copy
            except Exception:
                provided_headers = {}

        if backend_type == 'bedrock':
            # Remove any ambient AWS credentials from the environment (they
            # are saved in ``_removed_aws_env``) so the Bedrock client does
            # not pick them up; auth goes through the gateway token instead.
            self._removed_aws_env: dict[str, str] = {}
            for _v in (
                'AWS_ACCESS_KEY_ID',
                'AWS_SECRET_ACCESS_KEY',
                'AWS_SESSION_TOKEN',
            ):
                if _v in os.environ:
                    self._removed_aws_env[_v] = os.environ.pop(_v)

            token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
            token = api_key if api_key is not None else token_env
            # Generate a per-instance client ID for tracing Bedrock calls.
            self._client_id = str(uuid.uuid4())
            self._client = ChatBedrockConverse(
                base_url=info.connection_url,
                model=actual_model,
                **kwargs,
            )

            # Attempt to inject the Authorization header for downstream HTTP
            # layers. Not all backends expose a direct header map; when one
            # is missing, the headers are stashed and injected per call.
            self._auth_header = None
            merged_headers: dict[str, str] = {}
            if provided_headers:
                merged_headers.update(provided_headers)
            if token:
                merged_headers.setdefault('Authorization', f'Bearer {token}')
            # Always include X-ClientID on the Bedrock path.
            merged_headers.setdefault('X-ClientID', self._client_id)
            if merged_headers:
                # Set directly if the backend exposes a default_headers dict,
                # without overwriting any headers it already defines.
                default_headers = getattr(self._client, 'default_headers', None)
                if isinstance(default_headers, dict):
                    for key, value in merged_headers.items():
                        default_headers.setdefault(key, value)
                else:
                    self._auth_header = merged_headers  # fallback for invoke/stream
        else:
            token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
            token = api_key if api_key is not None else token_env
            openai_kwargs = dict(
                base_url=info.connection_url,
                api_key=token,
                model=actual_model,
            )
            if http_client is not None:
                # Some langchain-openai versions accept an 'http_client'
                # parameter for a custom transport; pass it through only
                # when one was provided.
                openai_kwargs['http_client'] = http_client
            self._client = ChatOpenAI(
                **openai_kwargs,
                **kwargs,
            )

        self._backend_type = backend_type
        self.model_name = model_name  # external identifier provided by caller
        self.actual_model = actual_model  # model portion after the prefix
        self.prefix = prefix  # normalized prefix
        self.connection_url = info.connection_url

    @classmethod
    def _parse_identifier(cls, identifier: str) -> tuple[str, str]:
        for sep in ('.', ':', '/'):
            if sep in identifier:
                head, tail = identifier.split(sep, 1)
                prefix = head.strip().lower()
                if prefix in cls._VALID_PREFIXES:
                    return prefix, tail.strip()
        # No recognized prefix: treat the whole string as the model name.
        return 'aura', identifier.strip()
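
    # Illustrative parses (the identifiers below are hypothetical):
    #   'aura-amz.anthropic.claude-3' -> ('aura-amz', 'anthropic.claude-3')
    #   'aura-azr:gpt-4o'             -> ('aura-azr', 'gpt-4o')
    #   'gpt-4o-mini'                 -> ('aura', 'gpt-4o-mini')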

    # ---------------------------------------------------------------------
    # Delegation layer
    # ---------------------------------------------------------------------
    def __getattr__(self, item: str) -> Any:
        return getattr(self._client, item)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _maybe_inject_headers(self, kwargs: dict[str, Any]) -> None:
        """Inject Bedrock auth headers into kwargs if only a fallback exists.

        If the Bedrock client accepted headers via its own internal
        ``default_headers``, nothing needs to happen here. When the headers
        had to be stashed in ``_auth_header``, they are added to each
        outbound call that allows a ``headers`` kwarg and has not already
        provided its own.
        """
        headers = getattr(self, '_auth_header', None)
        if (
            self._backend_type == 'bedrock'
            and headers
            and 'headers' not in kwargs
        ):
            kwargs['headers'] = headers
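
    # For example (hypothetical values), when only the fallback path was
    # available, ``chat.invoke('hi')`` forwards as:
    #   self._client.invoke('hi', headers={'Authorization': 'Bearer <token>',
    #                                      'X-ClientID': '<uuid>'})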

    def as_base(self) -> Any:
        """Return the underlying backend client instance.

        This gives callers direct access to provider-specific methods or
        configuration that aren't surfaced by the experimental wrapper.
        """
        return self._client

    def invoke(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return self._client.invoke(*args, **kwargs)

    async def ainvoke(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return await self._client.ainvoke(*args, **kwargs)

    def stream(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return self._client.stream(*args, **kwargs)

    async def astream(self, *args: Any, **kwargs: Any) -> AsyncIterator[Any]:
        self._maybe_inject_headers(kwargs)
        async for chunk in self._client.astream(*args, **kwargs):
            yield chunk

    # ------------------------------------------------------------------
    # Extended delegation for additional common chat model surface area.
    # Each method simply injects headers (if needed) and then forwards.
    # ------------------------------------------------------------------
    def generate(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return self._client.generate(*args, **kwargs)

    async def agenerate(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return await self._client.agenerate(*args, **kwargs)

    def predict(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return self._client.predict(*args, **kwargs)

    async def apredict(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return await self._client.apredict(*args, **kwargs)

    def predict_messages(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return self._client.predict_messages(*args, **kwargs)

    async def apredict_messages(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return await self._client.apredict_messages(*args, **kwargs)

    def batch(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return self._client.batch(*args, **kwargs)

    async def abatch(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return await self._client.abatch(*args, **kwargs)

    def apply(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return self._client.apply(*args, **kwargs)

    async def aapply(self, *args: Any, **kwargs: Any) -> Any:
        self._maybe_inject_headers(kwargs)
        return await self._client.aapply(*args, **kwargs)

    def __repr__(self) -> str:
        return (
            'SingleStoreExperimentalChat('
            f'identifier={self.model_name!r}, '
            f'actual_model={self.actual_model!r}, '
            f'prefix={self.prefix}, backend={self._backend_type})'
        )
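

# Illustrative usage (a sketch; assumes a configured SingleStore workspace
# context and SINGLESTOREDB_USER_TOKEN in the environment; the model
# identifier is hypothetical):
#
#     chat = SingleStoreExperimentalChat('aura-amz.anthropic.claude-3-haiku')
#     print(chat)                    # repr shows prefix/backend routing
#     reply = chat.invoke('Hello!')  # headers injected on the Bedrock path
#     raw = chat.as_base()           # direct access to the backend client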