11import asyncio
22import json
3+ import uuid
34from pathlib import Path
45from typing import List , Optional , Type
56
67import structlog
78from alembic import command as alembic_command
89from alembic .config import Config as AlembicConfig
9- from pydantic import BaseModel
10- from sqlalchemy import TextClause , text
10+ from pydantic import BaseModel , ValidationError
11+ from sqlalchemy import CursorResult , TextClause , text
1112from sqlalchemy .exc import OperationalError
1213from sqlalchemy .ext .asyncio import create_async_engine
1314
1415from codegate .db .fim_cache import FimCache
1516from codegate .db .models import (
17+ ActiveWorkspace ,
1618 Alert ,
1719 GetAlertsWithPromptAndOutputRow ,
1820 GetPromptWithOutputsRow ,
1921 Output ,
2022 Prompt ,
23+ Session ,
24+ Workspace ,
25+ WorkspaceActive ,
2126)
2227from codegate .pipeline .base import PipelineContext
2328
@@ -75,10 +80,14 @@ async def _execute_update_pydantic_model(
7580 async def record_request (self , prompt_params : Optional [Prompt ] = None ) -> Optional [Prompt ]:
7681 if prompt_params is None :
7782 return None
83+ # Get the active workspace to store the request
84+ active_workspace = await DbReader ().get_active_workspace ()
85+ workspace_id = active_workspace .id if active_workspace else "1"
86+ prompt_params .workspace_id = workspace_id
7887 sql = text (
7988 """
80- INSERT INTO prompts (id, timestamp, provider, request, type)
81- VALUES (:id, :timestamp, :provider, :request, :type)
89+ INSERT INTO prompts (id, timestamp, provider, request, type, workspace_id )
90+ VALUES (:id, :timestamp, :provider, :request, :type, :workspace_id )
8291 RETURNING *
8392 """
8493 )
@@ -223,26 +232,78 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
223232 except Exception as e :
224233 logger .error (f"Failed to record context: { context } ." , error = str (e ))
225234
235+ async def add_workspace (self , workspace_name : str ) -> Optional [Workspace ]:
236+ try :
237+ workspace = Workspace (id = str (uuid .uuid4 ()), name = workspace_name )
238+ except ValidationError as e :
239+ logger .error (f"Failed to create workspace with name: { workspace_name } : { str (e )} " )
240+ return None
241+
242+ sql = text (
243+ """
244+ INSERT INTO workspaces (id, name)
245+ VALUES (:id, :name)
246+ RETURNING *
247+ """
248+ )
249+ added_workspace = await self ._execute_update_pydantic_model (workspace , sql )
250+ return added_workspace
251+
252+ async def update_session (self , session : Session ) -> Optional [Session ]:
253+ sql = text (
254+ """
255+ INSERT INTO sessions (id, active_workspace_id, last_update)
256+ VALUES (:id, :active_workspace_id, :last_update)
257+ ON CONFLICT (id) DO UPDATE SET
258+ active_workspace_id = excluded.active_workspace_id, last_update = excluded.last_update
259+ WHERE id = excluded.id
260+ RETURNING *
261+ """
262+ )
263+ # We only pass an object to respect the signature of the function
264+ active_session = await self ._execute_update_pydantic_model (session , sql )
265+ return active_session
266+
226267
227268class DbReader (DbCodeGate ):
228269
229270 def __init__ (self , sqlite_path : Optional [str ] = None ):
230271 super ().__init__ (sqlite_path )
231272
273+ async def _dump_result_to_pydantic_model (
274+ self , model_type : Type [BaseModel ], result : CursorResult
275+ ) -> Optional [List [BaseModel ]]:
276+ try :
277+ if not result :
278+ return None
279+ rows = [model_type (** row ._asdict ()) for row in result .fetchall () if row ]
280+ return rows
281+ except Exception as e :
282+ logger .error (f"Failed to dump to pydantic model: { model_type } ." , error = str (e ))
283+ return None
284+
232285 async def _execute_select_pydantic_model (
233286 self , model_type : Type [BaseModel ], sql_command : TextClause
234- ) -> Optional [BaseModel ]:
287+ ) -> Optional [List [ BaseModel ] ]:
235288 async with self ._async_db_engine .begin () as conn :
236289 try :
237290 result = await conn .execute (sql_command )
238- if not result :
239- return None
240- rows = [model_type (** row ._asdict ()) for row in result .fetchall () if row ]
241- return rows
291+ return await self ._dump_result_to_pydantic_model (model_type , result )
242292 except Exception as e :
243293 logger .error (f"Failed to select model: { model_type } ." , error = str (e ))
244294 return None
245295
296+ async def _exec_select_conditions_to_pydantic (
297+ self , model_type : Type [BaseModel ], sql_command : TextClause , conditions : dict
298+ ) -> Optional [List [BaseModel ]]:
299+ async with self ._async_db_engine .begin () as conn :
300+ try :
301+ result = await conn .execute (sql_command , conditions )
302+ return await self ._dump_result_to_pydantic_model (model_type , result )
303+ except Exception as e :
304+ logger .error (f"Failed to select model with conditions: { model_type } ." , error = str (e ))
305+ return None
306+
246307 async def get_prompts_with_output (self ) -> List [GetPromptWithOutputsRow ]:
247308 sql = text (
248309 """
@@ -286,6 +347,54 @@ async def get_alerts_with_prompt_and_output(self) -> List[GetAlertsWithPromptAnd
286347 prompts = await self ._execute_select_pydantic_model (GetAlertsWithPromptAndOutputRow , sql )
287348 return prompts
288349
350+ async def get_workspaces (self ) -> List [WorkspaceActive ]:
351+ sql = text (
352+ """
353+ SELECT
354+ w.id, w.name, s.active_workspace_id
355+ FROM workspaces w
356+ LEFT JOIN sessions s ON w.id = s.active_workspace_id
357+ """
358+ )
359+ workspaces = await self ._execute_select_pydantic_model (WorkspaceActive , sql )
360+ return workspaces
361+
362+ async def get_workspace_by_name (self , name : str ) -> List [Workspace ]:
363+ sql = text (
364+ """
365+ SELECT
366+ id, name
367+ FROM workspaces
368+ WHERE name = :name
369+ """
370+ )
371+ conditions = {"name" : name }
372+ workspaces = await self ._exec_select_conditions_to_pydantic (Workspace , sql , conditions )
373+ return workspaces [0 ] if workspaces else None
374+
375+ async def get_sessions (self ) -> List [Session ]:
376+ sql = text (
377+ """
378+ SELECT
379+ id, active_workspace_id, last_update
380+ FROM sessions
381+ """
382+ )
383+ sessions = await self ._execute_select_pydantic_model (Session , sql )
384+ return sessions
385+
386+ async def get_active_workspace (self ) -> Optional [ActiveWorkspace ]:
387+ sql = text (
388+ """
389+ SELECT
390+ w.id, w.name, s.id as session_id, s.last_update
391+ FROM sessions s
392+ INNER JOIN workspaces w ON w.id = s.active_workspace_id
393+ """
394+ )
395+ active_workspace = await self ._execute_select_pydantic_model (ActiveWorkspace , sql )
396+ return active_workspace [0 ] if active_workspace else None
397+
289398
290399def init_db_sync (db_path : Optional [str ] = None ):
291400 """DB will be initialized in the constructor in case it doesn't exist."""
@@ -307,5 +416,23 @@ def init_db_sync(db_path: Optional[str] = None):
307416 logger .info ("DB initialized successfully." )
308417
309418
419+ def init_session_if_not_exists (db_path : Optional [str ] = None ):
420+ import datetime
421+
422+ db_reader = DbReader (db_path )
423+ sessions = asyncio .run (db_reader .get_sessions ())
424+ # If there are no sessions, create a new one
425+ # TODO: For the moment there's a single session. If it already exists, we don't create a new one
426+ if not sessions :
427+ session = Session (
428+ id = str (uuid .uuid4 ()),
429+ active_workspace_id = "1" ,
430+ last_update = datetime .datetime .now (datetime .timezone .utc ),
431+ )
432+ db_recorder = DbRecorder (db_path )
433+ asyncio .run (db_recorder .update_session (session ))
434+ logger .info ("Session in DB initialized successfully." )
435+
436+
310437if __name__ == "__main__" :
311438 init_db_sync ()
0 commit comments