diff --git a/singlestoredb/ai/chat.py b/singlestoredb/ai/chat.py index 100bcb454..c637d390d 100644 --- a/singlestoredb/ai/chat.py +++ b/singlestoredb/ai/chat.py @@ -24,3 +24,17 @@ def __init__(self, model_name: str, **kwargs: Any): model=model_name, **kwargs, ) + + +class SingleStoreChat(ChatOpenAI): + def __init__(self, model_name: str, **kwargs: Any): + inference_api_manger = ( + get_workspace_manager().organizations.current.inference_apis + ) + info = inference_api_manger.get(model_name=model_name) + super().__init__( + base_url=info.connection_url, + api_key=os.environ.get('SINGLESTOREDB_USER_TOKEN'), + model=model_name, + **kwargs, + )