Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 258fc98

Browse files
Created necessary methods for Persona CRUD (#1232)
* Created necessary methods for Persona CRUD Closes: #1219 Some changes in this PR - Renamed Semantic Router to PersonaManager: The motivation is that the only semantic routing we're doing is based on persona, so let's just call it that way - Created update and delete methods for Persona - Added tests for the whole Persona CRUD * linting issues
1 parent da69ec0 commit 258fc98

File tree

6 files changed

+316
-27
lines changed

6 files changed

+316
-27
lines changed

src/codegate/api/v1.py

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212
from codegate import __version__
1313
from codegate.api import v1_models, v1_processing
1414
from codegate.db.connection import AlreadyExistsError, DbReader
15-
from codegate.db.models import AlertSeverity, WorkspaceWithModel
15+
from codegate.db.models import AlertSeverity, Persona, WorkspaceWithModel
16+
from codegate.muxing.persona import (
17+
PersonaDoesNotExistError,
18+
PersonaManager,
19+
PersonaSimilarDescriptionError,
20+
)
1621
from codegate.providers import crud as provendcrud
1722
from codegate.workspaces import crud
1823

@@ -21,6 +26,7 @@
2126
v1 = APIRouter()
2227
wscrud = crud.WorkspaceCrud()
2328
pcrud = provendcrud.ProviderCrud()
29+
persona_manager = PersonaManager()
2430

2531
# This is a singleton object
2632
dbreader = DbReader()
@@ -665,3 +671,89 @@ async def get_workspace_token_usage(workspace_name: str) -> v1_models.TokenUsage
665671
except Exception:
666672
logger.exception("Error while getting messages")
667673
raise HTTPException(status_code=500, detail="Internal server error")
674+
675+
676+
@v1.get("/personas", tags=["Personas"], generate_unique_id_function=uniq_name)
async def list_personas() -> List[Persona]:
    """Return every persona known to the database reader.

    Any error raised while reading is logged with its traceback and
    reported to the client as HTTP 500.
    """
    try:
        return await dbreader.get_all_personas()
    except Exception:
        logger.exception("Error while getting personas")
        raise HTTPException(status_code=500, detail="Internal server error")
685+
686+
687+
@v1.get("/personas/{persona_name}", tags=["Personas"], generate_unique_id_function=uniq_name)
async def get_persona(persona_name: str) -> Persona:
    """Get a persona by name.

    Responds with HTTP 404 when the lookup returns nothing for
    ``persona_name`` and with HTTP 500 when the lookup itself fails.
    """
    try:
        persona = await dbreader.get_persona_by_name(persona_name)
        if not persona:
            raise HTTPException(status_code=404, detail=f"Persona {persona_name} not found")
        return persona
    except HTTPException:
        # The 404 raised above must reach the client unchanged rather than
        # being converted into a 500 by the generic handler below.
        raise
    except Exception:
        logger.exception(f"Error while getting persona {persona_name}")
        raise HTTPException(status_code=500, detail="Internal server error")
700+
701+
702+
@v1.post("/personas", tags=["Personas"], generate_unique_id_function=uniq_name, status_code=201)
async def create_persona(request: v1_models.PersonaRequest) -> Persona:
    """Create a persona from the request and return the stored record.

    Error mapping:
    - PersonaSimilarDescriptionError -> HTTP 409 (description too similar)
    - AlreadyExistsError             -> HTTP 409 (persona already exists)
    - anything else                  -> HTTP 500
    """
    log_message = "Error while creating persona"
    try:
        await persona_manager.add_persona(request.name, request.description)
        # Read the persona back so the response carries the stored record.
        return await dbreader.get_persona_by_name(request.name)
    except PersonaSimilarDescriptionError:
        logger.exception(log_message)
        raise HTTPException(status_code=409, detail="Persona has a similar description to another")
    except AlreadyExistsError:
        logger.exception(log_message)
        raise HTTPException(status_code=409, detail="Persona already exists")
    except Exception:
        logger.exception(log_message)
        raise HTTPException(status_code=500, detail="Internal server error")
718+
719+
720+
@v1.put("/personas/{persona_name}", tags=["Personas"], generate_unique_id_function=uniq_name)
async def update_persona(persona_name: str, request: v1_models.PersonaUpdateRequest) -> Persona:
    """Rename and/or re-describe the persona currently called ``persona_name``.

    Error mapping:
    - PersonaSimilarDescriptionError -> HTTP 409 (description too similar)
    - PersonaDoesNotExistError       -> HTTP 404
    - AlreadyExistsError             -> HTTP 409 (persona already exists)
    - anything else                  -> HTTP 500
    """
    log_message = "Error while updating persona"
    try:
        await persona_manager.update_persona(
            persona_name, request.new_name, request.new_description
        )
        # After the update the persona is stored under its new name.
        return await dbreader.get_persona_by_name(request.new_name)
    except PersonaSimilarDescriptionError:
        logger.exception(log_message)
        raise HTTPException(status_code=409, detail="Persona has a similar description to another")
    except PersonaDoesNotExistError:
        logger.exception(log_message)
        raise HTTPException(status_code=404, detail="Persona does not exist")
    except AlreadyExistsError:
        logger.exception(log_message)
        raise HTTPException(status_code=409, detail="Persona already exists")
    except Exception:
        logger.exception(log_message)
        raise HTTPException(status_code=500, detail="Internal server error")
741+
742+
743+
@v1.delete(
    "/personas/{persona_name}",
    tags=["Personas"],
    generate_unique_id_function=uniq_name,
    status_code=204,
)
async def delete_persona(persona_name: str):
    """Delete a persona.

    Returns an empty 204 response on success. Responds with HTTP 404 when
    the persona manager reports that the persona does not exist, and with
    HTTP 500 on any other failure.
    """
    try:
        await persona_manager.delete_persona(persona_name)
        return Response(status_code=204)
    except PersonaDoesNotExistError:
        # Log message corrected: this is the delete handler, not the update one.
        logger.exception("Error while deleting persona")
        raise HTTPException(status_code=404, detail="Persona does not exist")
    except Exception:
        logger.exception("Error while deleting persona")
        raise HTTPException(status_code=500, detail="Internal server error")

src/codegate/api/v1_models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,3 +315,21 @@ class ModelByProvider(pydantic.BaseModel):
315315

316316
def __str__(self):
317317
return f"{self.provider_name} / {self.name}"
318+
319+
320+
class PersonaRequest(pydantic.BaseModel):
    """Request body used to create a new Persona."""

    # Name of the persona to create.
    name: str
    # Free-text description of the persona.
    description: str
327+
328+
329+
class PersonaUpdateRequest(pydantic.BaseModel):
    """Request body used to update an existing Persona."""

    # Name the persona should have after the update.
    new_name: str
    # Description the persona should have after the update.
    new_description: str

src/codegate/db/connection.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -561,15 +561,41 @@ async def add_persona(self, persona: PersonaEmbedding) -> None:
561561
)
562562

563563
try:
564-
# For Pydantic we convert the numpy array to string when serializing with .model_dumpy()
565-
# We need to convert it back to a numpy array before inserting it into the DB.
566-
persona_dict = persona.model_dump()
567-
persona_dict["description_embedding"] = persona.description_embedding
568-
await self._execute_with_no_return(sql, persona_dict)
564+
await self._execute_with_no_return(sql, persona.model_dump())
569565
except IntegrityError as e:
570566
logger.debug(f"Exception type: {type(e)}")
571567
raise AlreadyExistsError(f"Persona '{persona.name}' already exists.")
572568

569+
async def update_persona(self, persona: PersonaEmbedding) -> None:
    """
    Update an existing Persona in the DB.

    Overwrites the name, description and description embedding of the row
    whose id matches ``persona.id``.

    Raises:
        AlreadyExistsError: when the UPDATE fails with an IntegrityError
            (presumably a uniqueness constraint on the persona name;
            confirm against the schema).
    """
    sql = text(
        """
        UPDATE personas
        SET name = :name,
            description = :description,
            description_embedding = :description_embedding
        WHERE id = :id
        """
    )

    try:
        await self._execute_with_no_return(sql, persona.model_dump())
    except IntegrityError as e:
        logger.debug(f"Exception type: {type(e)}")
        # Chain the original error so the DB-level cause stays in the traceback.
        raise AlreadyExistsError(f"Persona '{persona.name}' already exists.") from e
590+
591+
async def delete_persona(self, persona_id: str) -> None:
    """
    Remove the persona row whose id equals ``persona_id`` from the DB.
    """
    await self._execute_with_no_return(
        text("DELETE FROM personas WHERE id = :id"),
        {"id": persona_id},
    )
598+
573599

574600
class DbReader(DbCodeGate):
575601
def __init__(self, sqlite_path: Optional[str] = None, *args, **kwargs):
@@ -588,14 +614,20 @@ async def _dump_result_to_pydantic_model(
588614
return None
589615

590616
async def _execute_select_pydantic_model(
591-
self, model_type: Type[BaseModel], sql_command: TextClause
617+
self,
618+
model_type: Type[BaseModel],
619+
sql_command: TextClause,
620+
should_raise: bool = False,
592621
) -> Optional[List[BaseModel]]:
593622
async with self._async_db_engine.begin() as conn:
594623
try:
595624
result = await conn.execute(sql_command)
596625
return await self._dump_result_to_pydantic_model(model_type, result)
597626
except Exception as e:
598627
logger.error(f"Failed to select model: {model_type}.", error=str(e))
628+
# Exposes errors to the caller
629+
if should_raise:
630+
raise e
599631
return None
600632

601633
async def _exec_select_conditions_to_pydantic(
@@ -1005,7 +1037,7 @@ async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]:
10051037
return personas[0] if personas else None
10061038

10071039
async def get_distance_to_existing_personas(
1008-
self, query_embedding: np.ndarray
1040+
self, query_embedding: np.ndarray, exclude_id: Optional[str]
10091041
) -> List[PersonaDistance]:
10101042
"""
10111043
Get the distance between a persona and a query embedding.
@@ -1019,6 +1051,13 @@ async def get_distance_to_existing_personas(
10191051
FROM personas
10201052
"""
10211053
conditions = {"query_embedding": query_embedding}
1054+
1055+
# Exclude this persona from the SQL query. Used when checking the descriptions
1056+
# for updating the persona. Exclude the persona to update itself from the query.
1057+
if exclude_id:
1058+
sql += " WHERE id != :exclude_id"
1059+
conditions["exclude_id"] = exclude_id
1060+
10221061
persona_distances = await self._exec_vec_db_query_to_pydantic(
10231062
sql, conditions, PersonaDistance
10241063
)
@@ -1045,6 +1084,20 @@ async def get_distance_to_persona(
10451084
)
10461085
return persona_distance[0]
10471086

1087+
async def get_all_personas(self) -> List[Persona]:
    """
    Fetch the id, name and description of every stored persona.

    The select helper is called with should_raise=True.
    """
    query = text(
        """
        SELECT
            id, name, description
        FROM personas
        """
    )
    return await self._execute_select_pydantic_model(Persona, query, should_raise=True)
1100+
10481101

10491102
class DbTransaction:
10501103
def __init__(self):

src/codegate/db/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def nd_array_custom_before_validator(x):
252252

253253
def nd_array_custom_serializer(x):
254254
# custome serialization logic
255-
return str(x)
255+
return x
256256

257257

258258
# Pydantic doesn't support numpy arrays out of the box hence we need to construct a custom type.

src/codegate/muxing/semantic_router.py renamed to src/codegate/muxing/persona.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unicodedata
22
import uuid
3+
from typing import Optional
34

45
import numpy as np
56
import regex as re
@@ -32,11 +33,12 @@ class PersonaSimilarDescriptionError(Exception):
3233
pass
3334

3435

35-
class SemanticRouter:
36+
class PersonaManager:
3637

3738
def __init__(self):
38-
self._inference_engine = LlamaCppInferenceEngine()
39+
Config.load()
3940
conf = Config.get_config()
41+
self._inference_engine = LlamaCppInferenceEngine()
4042
self._embeddings_model = f"{conf.model_base_path}/{conf.embedding_model}"
4143
self._n_gpu = conf.chat_model_n_gpu_layers
4244
self._persona_threshold = conf.persona_threshold
@@ -110,13 +112,15 @@ async def _embed_text(self, text: str) -> np.ndarray:
110112
logger.debug("Text embedded in semantic routing", text=cleaned_text[:50])
111113
return np.array(embed_list[0], dtype=np.float32)
112114

113-
async def _is_persona_description_diff(self, emb_persona_desc: np.ndarray) -> bool:
115+
async def _is_persona_description_diff(
116+
self, emb_persona_desc: np.ndarray, exclude_id: Optional[str]
117+
) -> bool:
114118
"""
115119
Check if the persona description is different enough from existing personas.
116120
"""
117121
# The distance calculation is done in the database
118122
persona_distances = await self._db_reader.get_distance_to_existing_personas(
119-
emb_persona_desc
123+
emb_persona_desc, exclude_id
120124
)
121125
if not persona_distances:
122126
return True
@@ -131,16 +135,26 @@ async def _is_persona_description_diff(self, emb_persona_desc: np.ndarray) -> bo
131135
return False
132136
return True
133137

134-
async def add_persona(self, persona_name: str, persona_desc: str) -> None:
138+
async def _validate_persona_description(
139+
self, persona_desc: str, exclude_id: str = None
140+
) -> np.ndarray:
135141
"""
136-
Add a new persona to the database. The persona description is embedded
137-
and stored in the database.
142+
Validate the persona description by embedding the text and checking if it is
143+
different enough from existing personas.
138144
"""
139145
emb_persona_desc = await self._embed_text(persona_desc)
140-
if not await self._is_persona_description_diff(emb_persona_desc):
146+
if not await self._is_persona_description_diff(emb_persona_desc, exclude_id):
141147
raise PersonaSimilarDescriptionError(
142148
"The persona description is too similar to existing personas."
143149
)
150+
return emb_persona_desc
151+
152+
async def add_persona(self, persona_name: str, persona_desc: str) -> None:
153+
"""
154+
Add a new persona to the database. The persona description is embedded
155+
and stored in the database.
156+
"""
157+
emb_persona_desc = await self._validate_persona_description(persona_desc)
144158

145159
new_persona = db_models.PersonaEmbedding(
146160
id=str(uuid.uuid4()),
@@ -151,6 +165,43 @@ async def add_persona(self, persona_name: str, persona_desc: str) -> None:
151165
await self._db_recorder.add_persona(new_persona)
152166
logger.info(f"Added persona {persona_name} to the database.")
153167

168+
async def update_persona(
    self, persona_name: str, new_persona_name: str, new_persona_desc: str
) -> None:
    """
    Update an existing persona in the database. The name and description are
    updated in the database, but the ID remains the same.

    Raises:
        PersonaDoesNotExistError: when no persona named ``persona_name`` is found.
        PersonaSimilarDescriptionError: when the new description is too similar
            to the description of another persona.
    """
    # First we check if the persona exists, if not we raise an error
    found_persona = await self._db_reader.get_persona_by_name(persona_name)
    if not found_persona:
        # Message corrected from "Person" to "Persona" to match delete_persona.
        raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")

    # The persona being updated is excluded from the similarity check so its
    # own current description cannot reject the update.
    emb_persona_desc = await self._validate_persona_description(
        new_persona_desc, exclude_id=found_persona.id
    )

    # Then we update the attributes in the database
    updated_persona = db_models.PersonaEmbedding(
        id=found_persona.id,
        name=new_persona_name,
        description=new_persona_desc,
        description_embedding=emb_persona_desc,
    )
    await self._db_recorder.update_persona(updated_persona)
    logger.info(f"Updated persona {persona_name} in the database.")
193+
194+
async def delete_persona(self, persona_name: str) -> None:
    """
    Look up the persona called ``persona_name`` and delete it by its id.

    Raises PersonaDoesNotExistError when the lookup finds nothing.
    """
    existing = await self._db_reader.get_persona_by_name(persona_name)
    if not existing:
        raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")

    # The recorder deletes by id, so the name is resolved to an id first.
    target_id = existing.id
    await self._db_recorder.delete_persona(target_id)
    logger.info(f"Deleted persona {persona_name} from the database.")
204+
154205
async def check_persona_match(self, persona_name: str, query: str) -> bool:
155206
"""
156207
Check if the query matches the persona description. A vector similarity

0 commit comments

Comments
 (0)