diff --git a/singlestoredb/ai/chat.py b/singlestoredb/ai/chat.py index 88481760..cc81ef18 100644 --- a/singlestoredb/ai/chat.py +++ b/singlestoredb/ai/chat.py @@ -90,9 +90,137 @@ def SingleStoreChatFactory( 'signature_version': UNSIGNED, 'retries': {'max_attempts': 1, 'mode': 'standard'}, } - if http_client is not None and http_client.timeout is not None: - cfg_kwargs['read_timeout'] = http_client.timeout - cfg_kwargs['connect_timeout'] = http_client.timeout + # Extract timeouts from http_client if provided + t = http_client.timeout if http_client is not None else None + connect_timeout = None + read_timeout = None + if t is not None: + if isinstance(t, httpx.Timeout): + if t.connect is not None: + connect_timeout = float(t.connect) + if t.read is not None: + read_timeout = float(t.read) + if connect_timeout is None and read_timeout is not None: + connect_timeout = read_timeout + if read_timeout is None and connect_timeout is not None: + read_timeout = connect_timeout + elif isinstance(t, (int, float)): + connect_timeout = float(t) + read_timeout = float(t) + if read_timeout is not None: + cfg_kwargs['read_timeout'] = read_timeout + if connect_timeout is not None: + cfg_kwargs['connect_timeout'] = connect_timeout + + cfg = Config(**cfg_kwargs) + client = boto3.client( + 'bedrock-runtime', + endpoint_url=info.connection_url, + region_name='us-east-1', + aws_access_key_id='placeholder', + aws_secret_access_key='placeholder', + config=cfg, + ) + + def _inject_headers(request: Any, **_ignored: Any) -> None: + """Inject dynamic auth/OBO headers prior to Bedrock sending.""" + if obo_token_getter is not None: + obo_val = obo_token_getter() + if obo_val: + request.headers['X-S2-OBO'] = obo_val + if token: + request.headers['Authorization'] = f'Bearer {token}' + request.headers.pop('X-Amz-Date', None) + request.headers.pop('X-Amz-Security-Token', None) + + emitter = client._endpoint._event_emitter + emitter.register_first( + 'before-send.bedrock-runtime.Converse', + _inject_headers, + ) + emitter.register_first( + 'before-send.bedrock-runtime.ConverseStream', + _inject_headers, + ) + emitter.register_first( + 'before-send.bedrock-runtime.InvokeModel', + _inject_headers, + ) + emitter.register_first( + 'before-send.bedrock-runtime.InvokeModelWithResponseStream', + _inject_headers, + ) + + return ChatBedrockConverse( + model_id=model_name, + endpoint_url=info.connection_url, + region_name='us-east-1', + aws_access_key_id='placeholder', + aws_secret_access_key='placeholder', + disable_streaming=not streaming, + client=client, + **kwargs, + ) + + # OpenAI / Azure OpenAI path + openai_kwargs = dict( + base_url=info.connection_url, + api_key=token, + model=model_name, + streaming=streaming, + ) + if http_client is not None: + openai_kwargs['http_client'] = http_client + return ChatOpenAI( + **openai_kwargs, + **kwargs, + ) + + +def NewSingleStoreChatFactory( + model_name: str, + api_key: Optional[str] = None, + streaming: bool = True, + http_client: Optional[httpx.Client] = None, + obo_token_getter: Optional[Callable[[], Optional[str]]] = None, + **kwargs: Any, +) -> Union[ChatOpenAI, ChatBedrockConverse]: + """Return a chat model instance (ChatOpenAI or ChatBedrockConverse). + """ + inference_api_manager = ( + manage_workspaces().organizations.current.inference_apis + ) + info = inference_api_manager.get(model_name=model_name) + token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN') + token = api_key if api_key is not None else token_env + + if info.hosting_platform == 'Amazon': + # Instantiate Bedrock client + cfg_kwargs = { + 'signature_version': UNSIGNED, + 'retries': {'max_attempts': 1, 'mode': 'standard'}, + } + # Extract timeouts from http_client if provided + t = http_client.timeout if http_client is not None else None + connect_timeout = None + read_timeout = None + if t is not None: + if isinstance(t, httpx.Timeout): + if t.connect is not None: + connect_timeout = float(t.connect) + if t.read is not None: + read_timeout = float(t.read) + if connect_timeout is None and read_timeout is not None: + connect_timeout = read_timeout + if read_timeout is None and connect_timeout is not None: + read_timeout = connect_timeout + elif isinstance(t, (int, float)): + connect_timeout = float(t) + read_timeout = float(t) + if read_timeout is not None: + cfg_kwargs['read_timeout'] = read_timeout + if connect_timeout is not None: + cfg_kwargs['connect_timeout'] = connect_timeout cfg = Config(**cfg_kwargs) client = boto3.client( diff --git a/singlestoredb/ai/embeddings.py b/singlestoredb/ai/embeddings.py index bd6c81ef..497a01bd 100644 --- a/singlestoredb/ai/embeddings.py +++ b/singlestoredb/ai/embeddings.py @@ -66,9 +66,126 @@ def SingleStoreEmbeddingsFactory( 'signature_version': UNSIGNED, 'retries': {'max_attempts': 1, 'mode': 'standard'}, } - if http_client is not None and http_client.timeout is not None: - cfg_kwargs['read_timeout'] = http_client.timeout - cfg_kwargs['connect_timeout'] = http_client.timeout + # Extract timeouts from http_client if provided + t = http_client.timeout if http_client is not None else None + connect_timeout = None + read_timeout = None + if t is not None: + if isinstance(t, httpx.Timeout): + if t.connect is not None: + connect_timeout = float(t.connect) + if t.read is not None: + read_timeout = float(t.read) + if connect_timeout is None and read_timeout is not None: + connect_timeout = read_timeout + if read_timeout is None and connect_timeout is not None: + read_timeout = connect_timeout + elif isinstance(t, (int, float)): + connect_timeout = float(t) + read_timeout = float(t) + if read_timeout is not None: + cfg_kwargs['read_timeout'] = read_timeout + if connect_timeout is not None: + cfg_kwargs['connect_timeout'] = connect_timeout + + cfg = Config(**cfg_kwargs) + client = boto3.client( + 'bedrock-runtime', + endpoint_url=info.connection_url, + region_name='us-east-1', + aws_access_key_id='placeholder', + aws_secret_access_key='placeholder', + config=cfg, + ) + + def _inject_headers(request: Any, **_ignored: Any) -> None: + """Inject dynamic auth/OBO headers prior to Bedrock sending.""" + if obo_token_getter is not None: + obo_val = obo_token_getter() + if obo_val: + request.headers['X-S2-OBO'] = obo_val + if token: + request.headers['Authorization'] = f'Bearer {token}' + request.headers.pop('X-Amz-Date', None) + request.headers.pop('X-Amz-Security-Token', None) + + emitter = client._endpoint._event_emitter + emitter.register_first( + 'before-send.bedrock-runtime.InvokeModel', + _inject_headers, + ) + emitter.register_first( + 'before-send.bedrock-runtime.InvokeModelWithResponseStream', + _inject_headers, + ) + + return BedrockEmbeddings( + model_id=model_name, + endpoint_url=info.connection_url, + region_name='us-east-1', + aws_access_key_id='placeholder', + aws_secret_access_key='placeholder', + client=client, + **kwargs, + ) + + # OpenAI / Azure OpenAI path + openai_kwargs = dict( + base_url=info.connection_url, + api_key=token, + model=model_name, + ) + if http_client is not None: + openai_kwargs['http_client'] = http_client + return OpenAIEmbeddings( + **openai_kwargs, + **kwargs, + ) + + +def NewSingleStoreEmbeddingsFactory( + model_name: str, + api_key: Optional[str] = None, + http_client: Optional[httpx.Client] = None, + obo_token_getter: Optional[Callable[[], Optional[str]]] = None, + **kwargs: Any, +) -> Union[OpenAIEmbeddings, BedrockEmbeddings]: + """Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings). + """ + inference_api_manager = ( + manage_workspaces().organizations.current.inference_apis + ) + info = inference_api_manager.get(model_name=model_name) + token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN') + token = api_key if api_key is not None else token_env + + if info.hosting_platform == 'Amazon': + # Instantiate Bedrock client + cfg_kwargs = { + 'signature_version': UNSIGNED, + 'retries': {'max_attempts': 1, 'mode': 'standard'}, + } + # Extract timeouts from http_client if provided + t = http_client.timeout if http_client is not None else None + connect_timeout = None + read_timeout = None + if t is not None: + if isinstance(t, httpx.Timeout): + if t.connect is not None: + connect_timeout = float(t.connect) + if t.read is not None: + read_timeout = float(t.read) + if connect_timeout is None and read_timeout is not None: + connect_timeout = read_timeout + if read_timeout is None and connect_timeout is not None: + read_timeout = connect_timeout + elif isinstance(t, (int, float)): + connect_timeout = float(t) + read_timeout = float(t) + if read_timeout is not None: + cfg_kwargs['read_timeout'] = read_timeout + if connect_timeout is not None: + cfg_kwargs['connect_timeout'] = connect_timeout cfg = Config(**cfg_kwargs) client = boto3.client(