Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 131 additions & 3 deletions singlestoredb/ai/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,137 @@ def SingleStoreChatFactory(
'signature_version': UNSIGNED,
'retries': {'max_attempts': 1, 'mode': 'standard'},
}
if http_client is not None and http_client.timeout is not None:
cfg_kwargs['read_timeout'] = http_client.timeout
cfg_kwargs['connect_timeout'] = http_client.timeout
# Extract timeouts from http_client if provided
t = http_client.timeout if http_client is not None else None
connect_timeout = None
read_timeout = None
if t is not None:
if isinstance(t, httpx.Timeout):
if t.connect is not None:
connect_timeout = float(t.connect)
if t.read is not None:
read_timeout = float(t.read)
if connect_timeout is None and read_timeout is not None:
connect_timeout = read_timeout
if read_timeout is None and connect_timeout is not None:
read_timeout = connect_timeout
elif isinstance(t, (int, float)):
connect_timeout = float(t)
read_timeout = float(t)
if read_timeout is not None:
cfg_kwargs['read_timeout'] = read_timeout
if connect_timeout is not None:
cfg_kwargs['connect_timeout'] = connect_timeout

cfg = Config(**cfg_kwargs)
client = boto3.client(
'bedrock-runtime',
endpoint_url=info.connection_url,
region_name='us-east-1',
aws_access_key_id='placeholder',
aws_secret_access_key='placeholder',
config=cfg,
)

def _inject_headers(request: Any, **_ignored: Any) -> None:
    """Rewrite auth-related headers on an outgoing Bedrock request.

    Intended to run as a botocore ``before-send`` handler. The handler
    mutates ``request.headers`` in place and returns ``None``.

    Parameters
    ----------
    request : Any
        The prepared request object; only its ``headers`` mapping is used.
    **_ignored : Any
        Extra keyword arguments supplied by the event emitter (unused).

    """
    # Resolve the OBO token on every call so a rotating token is honoured.
    obo_val = obo_token_getter() if obo_token_getter is not None else None

    headers = request.headers
    if obo_val:
        headers['X-S2-OBO'] = obo_val
    if token:
        headers['Authorization'] = f'Bearer {token}'

    # Drop AWS-specific headers; presumably the endpoint authenticates via
    # the bearer token instead of SigV4 -- TODO confirm.
    for aws_header in ('X-Amz-Date', 'X-Amz-Security-Token'):
        headers.pop(aws_header, None)

emitter = client._endpoint._event_emitter
emitter.register_first(
'before-send.bedrock-runtime.Converse',
_inject_headers,
)
emitter.register_first(
'before-send.bedrock-runtime.ConverseStream',
_inject_headers,
)
emitter.register_first(
'before-send.bedrock-runtime.InvokeModel',
_inject_headers,
)
emitter.register_first(
'before-send.bedrock-runtime.InvokeModelWithResponseStream',
_inject_headers,
)

return ChatBedrockConverse(
model_id=model_name,
endpoint_url=info.connection_url,
region_name='us-east-1',
aws_access_key_id='placeholder',
aws_secret_access_key='placeholder',
disable_streaming=not streaming,
client=client,
**kwargs,
)

# OpenAI / Azure OpenAI path
openai_kwargs = dict(
base_url=info.connection_url,
api_key=token,
model=model_name,
streaming=streaming,
)
if http_client is not None:
openai_kwargs['http_client'] = http_client
return ChatOpenAI(
**openai_kwargs,
**kwargs,
)


def NewSingleStoreChatFactory(
model_name: str,
api_key: Optional[str] = None,
streaming: bool = True,
http_client: Optional[httpx.Client] = None,
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
**kwargs: Any,
) -> Union[ChatOpenAI, ChatBedrockConverse]:
"""Return a chat model instance (ChatOpenAI or ChatBedrockConverse).
"""
inference_api_manager = (
manage_workspaces().organizations.current.inference_apis
)
info = inference_api_manager.get(model_name=model_name)
token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
token = api_key if api_key is not None else token_env

if info.hosting_platform == 'Amazon':
# Instantiate Bedrock client
cfg_kwargs = {
'signature_version': UNSIGNED,
'retries': {'max_attempts': 1, 'mode': 'standard'},
}
# Extract timeouts from http_client if provided
t = http_client.timeout if http_client is not None else None
connect_timeout = None
read_timeout = None
if t is not None:
if isinstance(t, httpx.Timeout):
if t.connect is not None:
connect_timeout = float(t.connect)
if t.read is not None:
read_timeout = float(t.read)
if connect_timeout is None and read_timeout is not None:
connect_timeout = read_timeout
if read_timeout is None and connect_timeout is not None:
read_timeout = connect_timeout
elif isinstance(t, (int, float)):
connect_timeout = float(t)
read_timeout = float(t)
if read_timeout is not None:
cfg_kwargs['read_timeout'] = read_timeout
if connect_timeout is not None:
cfg_kwargs['connect_timeout'] = connect_timeout

cfg = Config(**cfg_kwargs)
client = boto3.client(
Expand Down
123 changes: 120 additions & 3 deletions singlestoredb/ai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,126 @@ def SingleStoreEmbeddingsFactory(
'signature_version': UNSIGNED,
'retries': {'max_attempts': 1, 'mode': 'standard'},
}
if http_client is not None and http_client.timeout is not None:
cfg_kwargs['read_timeout'] = http_client.timeout
cfg_kwargs['connect_timeout'] = http_client.timeout
# Extract timeouts from http_client if provided
t = http_client.timeout if http_client is not None else None
connect_timeout = None
read_timeout = None
if t is not None:
if isinstance(t, httpx.Timeout):
if t.connect is not None:
connect_timeout = float(t.connect)
if t.read is not None:
read_timeout = float(t.read)
if connect_timeout is None and read_timeout is not None:
connect_timeout = read_timeout
if read_timeout is None and connect_timeout is not None:
read_timeout = connect_timeout
elif isinstance(t, (int, float)):
connect_timeout = float(t)
read_timeout = float(t)
if read_timeout is not None:
cfg_kwargs['read_timeout'] = read_timeout
if connect_timeout is not None:
cfg_kwargs['connect_timeout'] = connect_timeout

cfg = Config(**cfg_kwargs)
client = boto3.client(
'bedrock-runtime',
endpoint_url=info.connection_url,
region_name='us-east-1',
aws_access_key_id='placeholder',
aws_secret_access_key='placeholder',
config=cfg,
)

def _inject_headers(request: Any, **_ignored: Any) -> None:
    """Rewrite auth-related headers on an outgoing Bedrock request.

    Intended to run as a botocore ``before-send`` handler. The handler
    mutates ``request.headers`` in place and returns ``None``.

    Parameters
    ----------
    request : Any
        The prepared request object; only its ``headers`` mapping is used.
    **_ignored : Any
        Extra keyword arguments supplied by the event emitter (unused).

    """
    # Resolve the OBO token on every call so a rotating token is honoured.
    obo_val = obo_token_getter() if obo_token_getter is not None else None

    headers = request.headers
    if obo_val:
        headers['X-S2-OBO'] = obo_val
    if token:
        headers['Authorization'] = f'Bearer {token}'

    # Drop AWS-specific headers; presumably the endpoint authenticates via
    # the bearer token instead of SigV4 -- TODO confirm.
    for aws_header in ('X-Amz-Date', 'X-Amz-Security-Token'):
        headers.pop(aws_header, None)

emitter = client._endpoint._event_emitter
emitter.register_first(
'before-send.bedrock-runtime.InvokeModel',
_inject_headers,
)
emitter.register_first(
'before-send.bedrock-runtime.InvokeModelWithResponseStream',
_inject_headers,
)

return BedrockEmbeddings(
model_id=model_name,
endpoint_url=info.connection_url,
region_name='us-east-1',
aws_access_key_id='placeholder',
aws_secret_access_key='placeholder',
client=client,
**kwargs,
)

# OpenAI / Azure OpenAI path
openai_kwargs = dict(
base_url=info.connection_url,
api_key=token,
model=model_name,
)
if http_client is not None:
openai_kwargs['http_client'] = http_client
return OpenAIEmbeddings(
**openai_kwargs,
**kwargs,
)


def NewSingleStoreEmbeddingsFactory(
model_name: str,
api_key: Optional[str] = None,
http_client: Optional[httpx.Client] = None,
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
**kwargs: Any,
) -> Union[OpenAIEmbeddings, BedrockEmbeddings]:
"""Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings).
"""
inference_api_manager = (
manage_workspaces().organizations.current.inference_apis
)
info = inference_api_manager.get(model_name=model_name)
token_env = os.environ.get('SINGLESTOREDB_USER_TOKEN')
token = api_key if api_key is not None else token_env

if info.hosting_platform == 'Amazon':
# Instantiate Bedrock client
cfg_kwargs = {
'signature_version': UNSIGNED,
'retries': {'max_attempts': 1, 'mode': 'standard'},
}
# Extract timeouts from http_client if provided
t = http_client.timeout if http_client is not None else None
connect_timeout = None
read_timeout = None
if t is not None:
if isinstance(t, httpx.Timeout):
if t.connect is not None:
connect_timeout = float(t.connect)
if t.read is not None:
read_timeout = float(t.read)
if connect_timeout is None and read_timeout is not None:
connect_timeout = read_timeout
if read_timeout is None and connect_timeout is not None:
read_timeout = connect_timeout
elif isinstance(t, (int, float)):
connect_timeout = float(t)
read_timeout = float(t)
if read_timeout is not None:
cfg_kwargs['read_timeout'] = read_timeout
if connect_timeout is not None:
cfg_kwargs['connect_timeout'] = connect_timeout

cfg = Config(**cfg_kwargs)
client = boto3.client(
Expand Down