|
2 | 2 | import uuid |
3 | 3 | from typing import Any |
4 | 4 | from typing import AsyncIterator |
| 5 | +from typing import Callable |
5 | 6 | from typing import Optional |
6 | 7 |
|
7 | 8 | import httpx |
@@ -92,6 +93,7 @@ def __init__( |
92 | 93 | model_name: str, |
93 | 94 | http_client: Optional[httpx.Client] = None, |
94 | 95 | api_key: Optional[str] = None, |
| 96 | + obo_token_getter: Optional[Callable[[], Optional[str]]] = None, |
95 | 97 | **kwargs: Any, |
96 | 98 | ) -> None: |
97 | 99 | prefix, actual_model = self._parse_identifier(model_name) |
@@ -192,6 +194,10 @@ def __init__( |
192 | 194 | self.actual_model = actual_model # model portion after prefix |
193 | 195 | self.prefix = prefix # normalized prefix |
194 | 196 | self.connection_url = info.connection_url |
| 197 | + # Optional callable returning a fresh OBO token each request (Bedrock only). |
| 198 | + # If supplied, a new token will be fetched and injected into the |
| 199 | + # 'X-S2-OBO' header for every Bedrock request made via this wrapper. |
| 200 | + self._obo_token_getter = obo_token_getter |
195 | 201 |
|
196 | 202 | @classmethod |
197 | 203 | def _parse_identifier(cls, identifier: str) -> tuple[str, str]: |
@@ -223,16 +229,35 @@ def _maybe_inject_headers(self, kwargs: dict[str, Any]) -> None: |
223 | 229 | call that allows a `headers` kwarg and has not already provided |
224 | 230 | its own. |
225 | 231 | """ |
226 | | - if ( |
227 | | - self._backend_type == 'bedrock' |
228 | | - and hasattr(self, '_auth_header') |
| 232 | + if self._backend_type != 'bedrock': |
| 233 | + return |
| 234 | + |
| 235 | + # Start from existing headers in the call. |
| 236 | + # Copy to avoid mutating caller-provided dict in-place. |
| 237 | + call_headers: dict[str, str] = {} |
| 238 | + if 'headers' in kwargs and isinstance(kwargs['headers'], dict): |
| 239 | + call_headers = dict(kwargs['headers']) |
| 240 | + elif ( |
| 241 | + hasattr(self, '_auth_header') |
229 | 242 | and getattr(self, '_auth_header') |
230 | 243 | and 'headers' not in kwargs |
231 | 244 | ): |
232 | | - kwargs['headers'] = getattr( |
233 | | - self, |
234 | | - '_auth_header', |
235 | | - ) |
| 245 | + # Use fallback auth header if user did not pass any. |
| 246 | + call_headers = dict(getattr(self, '_auth_header')) |
| 247 | + |
| 248 | + # Dynamic OBO token injection (always fresh per request if getter provided) |
| 249 | + getter = getattr(self, '_obo_token_getter', None) |
| 250 | + if getter is not None: |
| 251 | + try: |
| 252 | + obo_token = getter() |
| 253 | + except Exception: |
| 254 | + obo_token = None |
| 255 | + if obo_token: |
| 256 | + # Overwrite any stale value. |
| 257 | + call_headers['X-S2-OBO'] = obo_token |
| 258 | + |
| 259 | + if call_headers: |
| 260 | + kwargs['headers'] = call_headers |
236 | 261 |
|
237 | 262 | def as_base(self) -> Any: |
238 | 263 | """Return the underlying backend client instance. |
|
0 commit comments