1
1
import asyncio
2
2
import json
3
+ import sqlite3
3
4
import uuid
4
5
from pathlib import Path
5
6
from typing import Dict , List , Optional , Type
6
7
8
+ import numpy as np
9
+ import sqlite_vec_sl_tmp
7
10
import structlog
8
11
from alembic import command as alembic_command
9
12
from alembic .config import Config as AlembicConfig
22
25
IntermediatePromptWithOutputUsageAlerts ,
23
26
MuxRule ,
24
27
Output ,
28
+ Persona ,
29
+ PersonaDistance ,
30
+ PersonaEmbedding ,
25
31
Prompt ,
26
32
ProviderAuthMaterial ,
27
33
ProviderEndpoint ,
@@ -65,7 +71,7 @@ def __new__(cls, *args, **kwargs):
65
71
# It should only be used for testing
66
72
if "_no_singleton" in kwargs and kwargs ["_no_singleton" ]:
67
73
kwargs .pop ("_no_singleton" )
68
- return super ().__new__ (cls , * args , ** kwargs )
74
+ return super ().__new__ (cls )
69
75
70
76
if cls ._instance is None :
71
77
cls ._instance = super ().__new__ (cls )
@@ -92,6 +98,22 @@ def __init__(self, sqlite_path: Optional[str] = None, **kwargs):
92
98
}
93
99
self ._async_db_engine = create_async_engine (** engine_dict )
94
100
101
def _get_vec_db_connection(self):
    """
    Open a dedicated sqlite3 connection with the vector extension loaded.

    Vector database connection is a separate connection to the SQLite database. aiosqlite
    does not support loading extensions, so we need to use the sqlite3 module to load the
    vector extension.

    Returns:
        An open sqlite3 connection to `self._db_path`. The connection is not closed
        here; the caller is responsible for closing it.

    Raises:
        Exception: Whatever connecting or loading the extension raised. The failure is
            logged and re-raised unchanged.
    """
    conn = None
    try:
        conn = sqlite3.connect(self._db_path)
        # Extension loading is switched on only around the load call itself.
        conn.enable_load_extension(True)
        sqlite_vec_sl_tmp.load(conn)
        conn.enable_load_extension(False)
        return conn
    except Exception:
        logger.exception("Failed to initialize vector database connection")
        # Do not leak a half-initialized connection when setup fails after connect().
        if conn is not None:
            conn.close()
        raise
116
+
95
117
def does_db_exist(self):
    """Return whether `self._db_path` currently refers to an existing file."""
    db_file = self._db_path
    return db_file.is_file()
97
119
@@ -523,6 +545,30 @@ async def add_mux(self, mux: MuxRule) -> MuxRule:
523
545
added_mux = await self ._execute_update_pydantic_model (mux , sql , should_raise = True )
524
546
return added_mux
525
547
548
async def add_persona(self, persona: PersonaEmbedding) -> None:
    """Add a new Persona to the DB.

    Inserts the persona's id, name, description and description embedding into the
    `personas` table.

    Args:
        persona: The persona to store, including its description embedding.

    Raises:
        AlreadyExistsError: If the insert fails with an IntegrityError, which is taken
            to mean the persona already exists. The IntegrityError is attached as the
            cause.
    """
    sql = text(
        """
        INSERT INTO personas (id, name, description, description_embedding)
        VALUES (:id, :name, :description, :description_embedding)
        """
    )

    # For Pydantic the numpy array is converted to a string when serializing with
    # .model_dump(), so the original embedding is put back before inserting into the DB.
    persona_dict = persona.model_dump()
    persona_dict["description_embedding"] = persona.description_embedding

    # Only the insert itself is guarded: it is the statement expected to raise
    # IntegrityError.
    try:
        await self._execute_with_no_return(sql, persona_dict)
    except IntegrityError as e:
        logger.debug(f"Exception type: {type(e)}")
        # Chain explicitly so the underlying DB error stays visible in tracebacks.
        raise AlreadyExistsError(f"Persona '{persona.name}' already exists.") from e
571
+
526
572
527
573
class DbReader (DbCodeGate ):
528
574
def __init__ (self , sqlite_path : Optional [str ] = None , * args , ** kwargs ):
@@ -569,6 +615,20 @@ async def _exec_select_conditions_to_pydantic(
569
615
raise e
570
616
return None
571
617
618
async def _exec_vec_db_query_to_pydantic(
    self, sql_command: str, conditions: dict, model_type: Type[BaseModel]
) -> List[BaseModel]:
    """
    Execute a query on the vector database. This is a separate connection to the SQLite
    database that has the vector extension loaded.

    Args:
        sql_command: SQL statement to execute.
        conditions: Parameters bound to the statement's placeholders.
        model_type: Model class instantiated once per row, with the row's columns
            passed as keyword arguments.

    Returns:
        One `model_type` instance per returned row, in cursor order.

    NOTE(review): the query runs synchronously through sqlite3 inside this coroutine, so
    it presumably blocks the event loop while executing - confirm this is acceptable.
    """
    conn = self._get_vec_db_connection()
    try:
        # sqlite3.Row exposes keys(), which lets each row be unpacked as keyword arguments.
        conn.row_factory = sqlite3.Row
        cursor = conn.cursor()
        return [model_type(**row) for row in cursor.execute(sql_command, conditions)]
    finally:
        # Always release the connection, even if the query or model construction fails.
        conn.close()
631
+
572
632
async def get_prompts_with_output (self , workpace_id : str ) -> List [GetPromptWithOutputsRow ]:
573
633
sql = text (
574
634
"""
@@ -893,6 +953,45 @@ async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]:
893
953
)
894
954
return muxes
895
955
956
async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]:
    """
    Look up a persona by its name.

    Returns the first matching row as a Persona, or None when the query yields nothing.
    """
    query = text(
        """
        SELECT
            id, name, description
        FROM personas
        WHERE name = :name
        """
    )
    matches = await self._exec_select_conditions_to_pydantic(
        Persona, query, {"name": persona_name}, should_raise=True
    )
    if not matches:
        return None
    return matches[0]
973
+
974
async def get_distance_to_persona(
    self, persona_id: str, query_embedding: np.ndarray
) -> PersonaDistance:
    """
    Compute the cosine distance between a stored persona's description embedding and
    the given query embedding.

    The query is run through the vector-database helper because it calls
    `vec_distance_cosine`. The first returned row is handed back; when no row matches
    `persona_id`, indexing the empty result raises IndexError.
    """
    query = """
        SELECT
            id,
            name,
            description,
            vec_distance_cosine(description_embedding, :query_embedding) as distance
        FROM personas
        WHERE id = :id
    """
    params = {"id": persona_id, "query_embedding": query_embedding}
    rows = await self._exec_vec_db_query_to_pydantic(query, params, PersonaDistance)
    return rows[0]
994
+
896
995
897
996
def init_db_sync (db_path : Optional [str ] = None ):
898
997
"""DB will be initialized in the constructor in case it doesn't exist."""
0 commit comments