11import asyncio
22import json
3+ import uuid
34from pathlib import Path
45from typing import List , Optional , Type
56
67import structlog
78from alembic import command as alembic_command
89from alembic .config import Config as AlembicConfig
910from pydantic import BaseModel
10- from sqlalchemy import TextClause , text
11+ from sqlalchemy import CursorResult , TextClause , text
1112from sqlalchemy .exc import OperationalError
1213from sqlalchemy .ext .asyncio import create_async_engine
1314
1415from codegate .db .fim_cache import FimCache
1516from codegate .db .models import (
17+ ActiveWorkspace ,
1618 Alert ,
1719 GetAlertsWithPromptAndOutputRow ,
1820 GetPromptWithOutputsRow ,
1921 Output ,
2022 Prompt ,
23+ Session ,
2124 Workspace ,
25+ WorkspaceActive ,
2226)
2327from codegate .pipeline .base import PipelineContext
2428
@@ -76,10 +80,14 @@ async def _execute_update_pydantic_model(
7680 async def record_request (self , prompt_params : Optional [Prompt ] = None ) -> Optional [Prompt ]:
7781 if prompt_params is None :
7882 return None
83+ # Get the active workspace to store the request
84+ active_workspace = await DbReader ().get_active_workspace ()
85+ workspace_id = active_workspace .id if active_workspace else "1"
86+ prompt_params .workspace_id = workspace_id
7987 sql = text (
8088 """
81- INSERT INTO prompts (id, timestamp, provider, request, type)
82- VALUES (:id, :timestamp, :provider, :request, :type)
89+ INSERT INTO prompts (id, timestamp, provider, request, type, workspace_id )
90+ VALUES (:id, :timestamp, :provider, :request, :type, :workspace_id )
8391 RETURNING *
8492 """
8593 )
@@ -224,26 +232,73 @@ async def record_context(self, context: Optional[PipelineContext]) -> None:
224232 except Exception as e :
225233 logger .error (f"Failed to record context: { context } ." , error = str (e ))
226234
235+ async def add_workspace (self , workspace_name : str ) -> Optional [Workspace ]:
236+ workspace = Workspace (id = str (uuid .uuid4 ()), name = workspace_name )
237+ sql = text (
238+ """
239+ INSERT INTO workspaces (id, name)
240+ VALUES (:id, :name)
241+ RETURNING *
242+ """
243+ )
244+ added_workspace = await self ._execute_update_pydantic_model (workspace , sql )
245+ return added_workspace
246+
247+ async def update_session (self , session : Session ) -> Optional [Session ]:
248+ sql = text (
249+ """
250+ INSERT INTO sessions (id, active_workspace_id, last_update)
251+ VALUES (:id, :active_workspace_id, :last_update)
252+ ON CONFLICT (id) DO UPDATE SET
253+ active_workspace_id = excluded.active_workspace_id, last_update = excluded.last_update
254+ WHERE id = excluded.id
255+ RETURNING *
256+ """
257+ )
258+ # We only pass an object to respect the signature of the function
259+ active_session = await self ._execute_update_pydantic_model (session , sql )
260+ return active_session
261+
227262
228263class DbReader (DbCodeGate ):
229264
230265 def __init__ (self , sqlite_path : Optional [str ] = None ):
231266 super ().__init__ (sqlite_path )
232267
268+ async def _dump_result_to_pydantic_model (
269+ self , model_type : Type [BaseModel ], result : CursorResult
270+ ) -> Optional [List [BaseModel ]]:
271+ try :
272+ if not result :
273+ return None
274+ rows = [model_type (** row ._asdict ()) for row in result .fetchall () if row ]
275+ return rows
276+ except Exception as e :
277+ logger .error (f"Failed to dump to pydantic model: { model_type } ." , error = str (e ))
278+ return None
279+
233280 async def _execute_select_pydantic_model (
234281 self , model_type : Type [BaseModel ], sql_command : TextClause
235- ) -> Optional [BaseModel ]:
282+ ) -> Optional [List [ BaseModel ] ]:
236283 async with self ._async_db_engine .begin () as conn :
237284 try :
238285 result = await conn .execute (sql_command )
239- if not result :
240- return None
241- rows = [model_type (** row ._asdict ()) for row in result .fetchall () if row ]
242- return rows
286+ return await self ._dump_result_to_pydantic_model (model_type , result )
243287 except Exception as e :
244288 logger .error (f"Failed to select model: { model_type } ." , error = str (e ))
245289 return None
246290
291+ async def _exec_select_conditions_to_pydantic (
292+ self , model_type : Type [BaseModel ], sql_command : TextClause , conditions : dict
293+ ) -> Optional [List [BaseModel ]]:
294+ async with self ._async_db_engine .begin () as conn :
295+ try :
296+ result = await conn .execute (sql_command , conditions )
297+ return await self ._dump_result_to_pydantic_model (model_type , result )
298+ except Exception as e :
299+ logger .error (f"Failed to select model with conditions: { model_type } ." , error = str (e ))
300+ return None
301+
247302 async def get_prompts_with_output (self ) -> List [GetPromptWithOutputsRow ]:
248303 sql = text (
249304 """
@@ -287,18 +342,54 @@ async def get_alerts_with_prompt_and_output(self) -> List[GetAlertsWithPromptAnd
287342 prompts = await self ._execute_select_pydantic_model (GetAlertsWithPromptAndOutputRow , sql )
288343 return prompts
289344
290- async def get_workspaces (self ) -> List [Workspace ]:
345+ async def get_workspaces (self ) -> List [WorkspaceActive ]:
291346 sql = text (
292347 """
293348 SELECT
294- id, name, is_active
349+ w.id, w.name, s.active_workspace_id
350+ FROM workspaces w
351+ LEFT JOIN sessions s ON w.id = s.active_workspace_id
352+ """
353+ )
354+ workspaces = await self ._execute_select_pydantic_model (WorkspaceActive , sql )
355+ return workspaces
356+
357+ async def get_workspace_by_name (self , name : str ) -> List [Workspace ]:
358+ sql = text (
359+ """
360+ SELECT
361+ id, name
295362 FROM workspaces
296- ORDER BY is_active DESC
363+ WHERE name = :name
297364 """
298365 )
299- workspaces = await self ._execute_select_pydantic_model (Workspace , sql )
366+ conditions = {"name" : name }
367+ workspaces = await self ._exec_select_conditions_to_pydantic (Workspace , sql , conditions )
300368 return workspaces
301369
370+ async def get_sessions (self ) -> List [Session ]:
371+ sql = text (
372+ """
373+ SELECT
374+ id, active_workspace_id, last_update
375+ FROM sessions
376+ """
377+ )
378+ sessions = await self ._execute_select_pydantic_model (Session , sql )
379+ return sessions
380+
381+ async def get_active_workspace (self ) -> Optional [ActiveWorkspace ]:
382+ sql = text (
383+ """
384+ SELECT
385+ w.id, w.name, s.id as session_id, s.last_update
386+ FROM sessions s
387+ INNER JOIN workspaces w ON w.id = s.active_workspace_id
388+ """
389+ )
390+ active_workspace = await self ._execute_select_pydantic_model (ActiveWorkspace , sql )
391+ return active_workspace [0 ] if active_workspace else None
392+
302393
303394def init_db_sync (db_path : Optional [str ] = None ):
304395 """DB will be initialized in the constructor in case it doesn't exist."""
@@ -320,5 +411,22 @@ def init_db_sync(db_path: Optional[str] = None):
320411 logger .info ("DB initialized successfully." )
321412
322413
414+ def init_session_if_not_exists (db_path : Optional [str ] = None ):
415+ import datetime
416+ db_reader = DbReader (db_path )
417+ sessions = asyncio .run (db_reader .get_sessions ())
418+ # If there are no sessions, create a new one
419+ # TODO: For the moment there's a single session. If it already exists, we don't create a new one
420+ if not sessions :
421+ session = Session (
422+ id = str (uuid .uuid4 ()),
423+ active_workspace_id = "1" ,
424+ last_update = datetime .datetime .now (datetime .timezone .utc )
425+ )
426+ db_recorder = DbRecorder (db_path )
427+ asyncio .run (db_recorder .update_session (session ))
428+ logger .info ("Session in DB initialized successfully." )
429+
430+
323431if __name__ == "__main__" :
324432 init_db_sync ()
0 commit comments