3030alert_queue = asyncio .Queue ()
3131fim_cache = FimCache ()
3232
33+
3334class AlreadyExistsError (Exception ):
3435 pass
3536
37+
3638class DbCodeGate :
3739 _instance = None
3840
@@ -246,16 +248,15 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
246248 except Exception as e :
247249 logger .error (f"Failed to record context: { context } ." , error = str (e ))
248250
249- async def add_workspace (self , workspace_name : str ) -> Optional [ Workspace ] :
251+ async def add_workspace (self , workspace_name : str ) -> Workspace :
250252 """Add a new workspace to the DB.
251253
252254 This handles validation and insertion of a new workspace.
253255
254256 It may raise a ValidationError if the workspace name is invalid.
255257 or a AlreadyExistsError if the workspace already exists.
256258 """
257- workspace = Workspace (id = str (uuid .uuid4 ()), name = workspace_name )
258-
259+ workspace = Workspace (id = str (uuid .uuid4 ()), name = workspace_name , system_prompt = None )
259260 sql = text (
260261 """
261262 INSERT INTO workspaces (id, name)
@@ -266,12 +267,28 @@ async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
266267
267268 try :
268269 added_workspace = await self ._execute_update_pydantic_model (
269- workspace , sql , should_raise = True )
270+ workspace , sql , should_raise = True
271+ )
270272 except IntegrityError as e :
271273 logger .debug (f"Exception type: { type (e )} " )
272274 raise AlreadyExistsError (f"Workspace { workspace_name } already exists." )
273275 return added_workspace
274276
277+ async def update_workspace (self , workspace : Workspace ) -> Workspace :
278+ sql = text (
279+ """
280+ UPDATE workspaces SET
281+ name = :name,
282+ system_prompt = :system_prompt
283+ WHERE id = :id
284+ RETURNING *
285+ """
286+ )
287+ updated_workspace = await self ._execute_update_pydantic_model (
288+ workspace , sql , should_raise = True
289+ )
290+ return updated_workspace
291+
275292 async def update_session (self , session : Session ) -> Optional [Session ]:
276293 sql = text (
277294 """
@@ -284,7 +301,7 @@ async def update_session(self, session: Session) -> Optional[Session]:
284301 """
285302 )
286303 # We only pass an object to respect the signature of the function
287- active_session = await self ._execute_update_pydantic_model (session , sql )
304+ active_session = await self ._execute_update_pydantic_model (session , sql , should_raise = True )
288305 return active_session
289306
290307
@@ -317,14 +334,21 @@ async def _execute_select_pydantic_model(
317334 return None
318335
319336 async def _exec_select_conditions_to_pydantic (
320- self , model_type : Type [BaseModel ], sql_command : TextClause , conditions : dict
337+ self ,
338+ model_type : Type [BaseModel ],
339+ sql_command : TextClause ,
340+ conditions : dict ,
341+ should_raise : bool = False ,
321342 ) -> Optional [List [BaseModel ]]:
322343 async with self ._async_db_engine .begin () as conn :
323344 try :
324345 result = await conn .execute (sql_command , conditions )
325346 return await self ._dump_result_to_pydantic_model (model_type , result )
326347 except Exception as e :
327348 logger .error (f"Failed to select model with conditions: { model_type } ." , error = str (e ))
349+ # Exposes errors to the caller
350+ if should_raise :
351+ raise e
328352 return None
329353
330354 async def get_prompts_with_output (self ) -> List [GetPromptWithOutputsRow ]:
@@ -382,17 +406,19 @@ async def get_workspaces(self) -> List[WorkspaceActive]:
382406 workspaces = await self ._execute_select_pydantic_model (WorkspaceActive , sql )
383407 return workspaces
384408
385- async def get_workspace_by_name (self , name : str ) -> List [Workspace ]:
409+ async def get_workspace_by_name (self , name : str ) -> Optional [Workspace ]:
386410 sql = text (
387411 """
388412 SELECT
389- id, name
413+ id, name, system_prompt
390414 FROM workspaces
391415 WHERE name = :name
392416 """
393417 )
394418 conditions = {"name" : name }
395- workspaces = await self ._exec_select_conditions_to_pydantic (Workspace , sql , conditions )
419+ workspaces = await self ._exec_select_conditions_to_pydantic (
420+ Workspace , sql , conditions , should_raise = True
421+ )
396422 return workspaces [0 ] if workspaces else None
397423
398424 async def get_sessions (self ) -> List [Session ]:
@@ -410,7 +436,7 @@ async def get_active_workspace(self) -> Optional[ActiveWorkspace]:
410436 sql = text (
411437 """
412438 SELECT
413- w.id, w.name, s.id as session_id, s.last_update
439+ w.id, w.name, w.system_prompt, s.id as session_id, s.last_update
414440 FROM sessions s
415441 INNER JOIN workspaces w ON w.id = s.active_workspace_id
416442 """
@@ -453,7 +479,11 @@ def init_session_if_not_exists(db_path: Optional[str] = None):
453479 last_update = datetime .datetime .now (datetime .timezone .utc ),
454480 )
455481 db_recorder = DbRecorder (db_path )
456- asyncio .run (db_recorder .update_session (session ))
482+ try :
483+ asyncio .run (db_recorder .update_session (session ))
484+ except Exception as e :
485+ logger .error (f"Failed to initialize session in DB: { e } " )
486+ return
457487 logger .info ("Session in DB initialized successfully." )
458488
459489
0 commit comments