@@ -175,3 +175,113 @@ def _inject_headers(request: Any, **_ignored: Any) -> None:
         **openai_kwargs,
         **kwargs,
     )
+
+
+def NewSingleStoreChatFactory(
+    model_name: str,
+    api_key: Optional[str] = None,
+    streaming: bool = True,
+    http_client: Optional[httpx.Client] = None,
+    obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
+    **kwargs: Any,
+) -> Union[ChatOpenAI, ChatBedrockConverse]:
+    """Return a chat model instance (ChatOpenAI or ChatBedrockConverse).
+    """
+    inference_api_manager = (
+        manage_workspaces().organizations.current.inference_apis
+    )
+    info = inference_api_manager.get(model_name=model_name)
+    token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
+    token = api_key if api_key is not None else token_env
+
+    if info.hosting_platform == 'Amazon':
+        # Configure the Bedrock runtime client (request signing disabled).
+        cfg_kwargs = {
+            'signature_version': UNSIGNED,
+            'retries': {'max_attempts': 1, 'mode': 'standard'},
+        }
+        # Extract connect/read timeouts from the httpx client, if provided.
+        t = http_client.timeout if http_client is not None else None
+        connect_timeout = None
+        read_timeout = None
+        if t is not None:
+            if isinstance(t, httpx.Timeout):
+                if t.connect is not None:
+                    connect_timeout = float(t.connect)
+                if t.read is not None:
+                    read_timeout = float(t.read)
+                if connect_timeout is None and read_timeout is not None:
+                    connect_timeout = read_timeout
+                if read_timeout is None and connect_timeout is not None:
+                    read_timeout = connect_timeout
+            elif isinstance(t, (int, float)):
+                connect_timeout = float(t)
+                read_timeout = float(t)
+        if read_timeout is not None:
+            cfg_kwargs['read_timeout'] = read_timeout
+        if connect_timeout is not None:
+            cfg_kwargs['connect_timeout'] = connect_timeout
+
+        cfg = Config(**cfg_kwargs)
+        client = boto3.client(
+            'bedrock-runtime',
+            endpoint_url=info.connection_url,
+            region_name='us-east-1',
+            aws_access_key_id='placeholder',
+            aws_secret_access_key='placeholder',
+            config=cfg,
+        )
+
+        def _inject_headers(request: Any, **_ignored: Any) -> None:
+            """Inject dynamic auth/OBO headers before Bedrock sends the request."""
+            if obo_token_getter is not None:
+                obo_val = obo_token_getter()
+                if obo_val:
+                    request.headers['X-S2-OBO'] = obo_val
+            if token:
+                request.headers['Authorization'] = f'Bearer {token}'
+            request.headers.pop('X-Amz-Date', None)
+            request.headers.pop('X-Amz-Security-Token', None)
+
+        emitter = client._endpoint._event_emitter
+        emitter.register_first(
+            'before-send.bedrock-runtime.Converse',
+            _inject_headers,
+        )
+        emitter.register_first(
+            'before-send.bedrock-runtime.ConverseStream',
+            _inject_headers,
+        )
+        emitter.register_first(
+            'before-send.bedrock-runtime.InvokeModel',
+            _inject_headers,
+        )
+        emitter.register_first(
+            'before-send.bedrock-runtime.InvokeModelWithResponseStream',
+            _inject_headers,
+        )
+
+        return ChatBedrockConverse(
+            model_id=model_name,
+            endpoint_url=info.connection_url,
+            region_name='us-east-1',
+            aws_access_key_id='placeholder',
+            aws_secret_access_key='placeholder',
+            disable_streaming=not streaming,
+            client=client,
+            **kwargs,
+        )
+
+    # OpenAI / Azure OpenAI path
+    openai_kwargs = dict(
+        base_url=info.connection_url,
+        api_key=token,
+        model=model_name,
+        streaming=streaming,
+    )
+    if http_client is not None:
+        openai_kwargs['http_client'] = http_client
+    return ChatOpenAI(
+        **openai_kwargs,
+        **kwargs,
+    )
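
For reference, a minimal usage sketch of the new factory (illustrative only: the import path and model name are placeholders, and the OBO getter is a hypothetical callable returning a per-request token):

    from singlestoredb_chat import NewSingleStoreChatFactory  # placeholder import path

    chat = NewSingleStoreChatFactory(
        model_name='my-hosted-model',      # placeholder model name
        api_key=None,                      # falls back to SINGLESTOREDB_USER_TOKEN
        streaming=True,
        obo_token_getter=lambda: None,     # hypothetical; return an OBO token when available
    )

    # The result is either ChatOpenAI or ChatBedrockConverse, so the standard
    # LangChain Runnable interface applies to both.
    print(chat.invoke('Hello!').content)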