Skip to content

Commit bbb8237

Browse files
Provide option for getting 'X-S2-OBO' token for every request.
1 parent 6ea4a6d commit bbb8237

File tree

1 file changed

+32
-7
lines changed

1 file changed

+32
-7
lines changed

singlestoredb/ai/chat.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import uuid
33
from typing import Any
44
from typing import AsyncIterator
5+
from typing import Callable
56
from typing import Optional
67

78
import httpx
@@ -92,6 +93,7 @@ def __init__(
9293
model_name: str,
9394
http_client: Optional[httpx.Client] = None,
9495
api_key: Optional[str] = None,
96+
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
9597
**kwargs: Any,
9698
) -> None:
9799
prefix, actual_model = self._parse_identifier(model_name)
@@ -192,6 +194,10 @@ def __init__(
192194
self.actual_model = actual_model # model portion after prefix
193195
self.prefix = prefix # normalized prefix
194196
self.connection_url = info.connection_url
197+
# Optional callable returning a fresh OBO token each request (Bedrock only).
198+
# If supplied, a new token will be fetched and injected into the
199+
# 'X-S2-OBO' header for every Bedrock request made via this wrapper.
200+
self._obo_token_getter = obo_token_getter
195201

196202
@classmethod
197203
def _parse_identifier(cls, identifier: str) -> tuple[str, str]:
@@ -223,16 +229,35 @@ def _maybe_inject_headers(self, kwargs: dict[str, Any]) -> None:
223229
call that allows a `headers` kwarg and has not already provided
224230
its own.
225231
"""
226-
if (
227-
self._backend_type == 'bedrock'
228-
and hasattr(self, '_auth_header')
232+
if self._backend_type != 'bedrock':
233+
return
234+
235+
# Start from existing headers in the call.
236+
# Copy to avoid mutating caller-provided dict in-place.
237+
call_headers: dict[str, str] = {}
238+
if 'headers' in kwargs and isinstance(kwargs['headers'], dict):
239+
call_headers = dict(kwargs['headers'])
240+
elif (
241+
hasattr(self, '_auth_header')
229242
and getattr(self, '_auth_header')
230243
and 'headers' not in kwargs
231244
):
232-
kwargs['headers'] = getattr(
233-
self,
234-
'_auth_header',
235-
)
245+
# Use fallback auth header if user did not pass any.
246+
call_headers = dict(getattr(self, '_auth_header'))
247+
248+
# Dynamic OBO token injection (always fresh per request if getter provided)
249+
getter = getattr(self, '_obo_token_getter', None)
250+
if getter is not None:
251+
try:
252+
obo_token = getter()
253+
except Exception:
254+
obo_token = None
255+
if obo_token:
256+
# Overwrite any stale value.
257+
call_headers['X-S2-OBO'] = obo_token
258+
259+
if call_headers:
260+
kwargs['headers'] = call_headers
236261

237262
def as_base(self) -> Any:
238263
"""Return the underlying backend client instance.

0 commit comments

Comments
 (0)