@@ -1,13 +1,9 @@
 import asyncio
-import copy
-import datetime
 import json
-import uuid
 from pathlib import Path
-from typing import AsyncGenerator, AsyncIterator, List, Optional
+from typing import List, Optional
 
 import structlog
-from litellm import ChatCompletionRequest, ModelResponse
 from pydantic import BaseModel
 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import create_async_engine
@@ -18,6 +14,7 @@
     GetAlertsWithPromptAndOutputRow,
     GetPromptWithOutputsRow,
 )
+from codegate.pipeline.base import PipelineContext
 
 logger = structlog.get_logger("codegate")
 alert_queue = asyncio.Queue()
@@ -103,97 +100,51 @@ async def _insert_pydantic_model(
             logger.error(f"Failed to insert model: {model}.", error=str(e))
             return None
 
-    async def record_request(
-        self, normalized_request: ChatCompletionRequest, is_fim_request: bool, provider_str: str
-    ) -> Optional[Prompt]:
-        request_str = None
-        if isinstance(normalized_request, BaseModel):
-            request_str = normalized_request.model_dump_json(exclude_none=True, exclude_unset=True)
-        else:
-            try:
-                request_str = json.dumps(normalized_request)
-            except Exception as e:
-                logger.error(f"Failed to serialize output: {normalized_request}", error=str(e))
-
-        if request_str is None:
-            logger.warning("No request found to record.")
-            return
-
-        # Create a new prompt record
-        prompt_params = Prompt(
-            id=str(uuid.uuid4()),  # Generate a new UUID for the prompt
-            timestamp=datetime.datetime.now(datetime.timezone.utc),
-            provider=provider_str,
-            type="fim" if is_fim_request else "chat",
-            request=request_str,
-        )
+    async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
+        if prompt_params is None:
+            return None
         sql = text(
             """
                 INSERT INTO prompts (id, timestamp, provider, request, type)
                 VALUES (:id, :timestamp, :provider, :request, :type)
                 RETURNING *
                 """
         )
-        return await self._insert_pydantic_model(prompt_params, sql)
-
-    async def _record_output(self, prompt: Prompt, output_str: str) -> Optional[Output]:
-        output_params = Output(
-            id=str(uuid.uuid4()),
-            prompt_id=prompt.id,
-            timestamp=datetime.datetime.now(datetime.timezone.utc),
-            output=output_str,
+        recorded_request = await self._insert_pydantic_model(prompt_params, sql)
+        logger.debug(f"Recorded request: {recorded_request}")
+        return recorded_request
+
+    async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
+        if not outputs:
+            return
+
+        first_output = outputs[0]
+        # Create a single entry on DB but encode all of the chunks in the stream as a list
+        # of JSON objects in the field `output`
+        output_db = Output(
+            id=first_output.id,
+            prompt_id=first_output.prompt_id,
+            timestamp=first_output.timestamp,
+            output=first_output.output,
         )
+        full_outputs = []
+        # Just store the model responses in the list of JSON objects.
+        for output in outputs:
+            full_outputs.append(output.output)
+        output_db.output = json.dumps(full_outputs)
+
         sql = text(
             """
                 INSERT INTO outputs (id, prompt_id, timestamp, output)
                 VALUES (:id, :prompt_id, :timestamp, :output)
                 RETURNING *
                 """
         )
-        return await self._insert_pydantic_model(output_params, sql)
-
-    async def record_output_stream(
-        self, prompt: Prompt, model_response: AsyncIterator
-    ) -> AsyncGenerator:
-        output_chunks = []
-        async for chunk in model_response:
-            if isinstance(chunk, BaseModel):
-                chunk_to_record = chunk.model_dump(exclude_none=True, exclude_unset=True)
-                output_chunks.append(chunk_to_record)
-            elif isinstance(chunk, dict):
-                output_chunks.append(copy.deepcopy(chunk))
-            else:
-                output_chunks.append({"chunk": str(chunk)})
-            yield chunk
-
-        if output_chunks:
-            # Record the output chunks
-            output_str = json.dumps(output_chunks)
-            await self._record_output(prompt, output_str)
-
-    async def record_output_non_stream(
-        self, prompt: Optional[Prompt], model_response: ModelResponse
-    ) -> Optional[Output]:
-        if prompt is None:
-            logger.warning("No prompt found to record output.")
-            return
+        recorded_output = await self._insert_pydantic_model(output_db, sql)
+        logger.debug(f"Recorded output: {recorded_output}")
+        return recorded_output
 
-        output_str = None
-        if isinstance(model_response, BaseModel):
-            output_str = model_response.model_dump_json(exclude_none=True, exclude_unset=True)
-        else:
-            try:
-                output_str = json.dumps(model_response)
-            except Exception as e:
-                logger.error(f"Failed to serialize output: {model_response}", error=str(e))
-
-        if output_str is None:
-            logger.warning("No output found to record.")
-            return
-
-        return await self._record_output(prompt, output_str)
-
-    async def record_alerts(self, alerts: List[Alert]) -> None:
+    async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
         if not alerts:
             return
         sql = text(
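The new `record_outputs` collapses a whole response stream into a single row: each chunk's `output` payload is appended to `full_outputs`, and the list is stored JSON-encoded in the `output` column. A minimal sketch of reading that column back; the helper name is hypothetical and not part of this change, and whether each element is itself JSON depends on how `Output.output` was serialized upstream:

```python
import json
from typing import Any, List

def decode_output_column(output_column: str) -> List[Any]:
    """Hypothetical reader for rows written by record_outputs.

    The column holds json.dumps(full_outputs): a JSON list whose
    elements are the per-chunk `output` strings.
    """
    chunks = json.loads(output_column)
    decoded = []
    for chunk in chunks:
        # A chunk may itself be JSON (e.g. a serialized model response);
        # fall back to the raw value when it is not.
        try:
            decoded.append(json.loads(chunk))
        except (TypeError, ValueError):
            decoded.append(chunk)
    return decoded
```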
@@ -208,15 +159,33 @@ async def record_alerts(self, alerts: List[Alert]) -> None:
                 """
         )
         # We can insert each alert independently in parallel.
+        alerts_tasks = []
         async with asyncio.TaskGroup() as tg:
             for alert in alerts:
                 try:
                     result = tg.create_task(self._insert_pydantic_model(alert, sql))
-                    if result and alert.trigger_category == "critical":
-                        await alert_queue.put(f"New alert detected: {alert.timestamp}")
+                    alerts_tasks.append(result)
                 except Exception as e:
                     logger.error(f"Failed to record alert: {alert}.", error=str(e))
-        return None
+
+        recorded_alerts = []
+        for alert_coro in alerts_tasks:
+            alert_result = alert_coro.result()
+            recorded_alerts.append(alert_result)
+            if alert_result and alert_result.trigger_category == "critical":
+                await alert_queue.put(f"New alert detected: {alert_result.timestamp}")
+
+        logger.debug(f"Recorded alerts: {recorded_alerts}")
+        return recorded_alerts
+
+    async def record_context(self, context: PipelineContext) -> None:
+        logger.info(
+            f"Recording context in DB. Output chunks: {len(context.output_responses)}. "
+            f"Alerts: {len(context.alerts_raised)}."
+        )
+        await self.record_request(context.input_request)
+        await self.record_outputs(context.output_responses)
+        await self.record_alerts(context.alerts_raised)
 
 
 class DbReader(DbCodeGate):
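Taken together, `record_context` becomes the single entry point: one call persists the prompt, the consolidated output row, and the alerts carried by a `PipelineContext`. A rough usage sketch; the `DbRecorder` class name, its no-argument constructor, and the fire-and-forget scheduling are assumptions for illustration, not confirmed by this diff:

```python
import asyncio

from codegate.db.connection import DbRecorder  # assumed location of this class
from codegate.pipeline.base import PipelineContext

async def finish_request(context: PipelineContext) -> None:
    recorder = DbRecorder()  # assumed no-arg construction
    # Persist prompt, outputs, and alerts in one call; scheduling it as a
    # task would keep DB writes off the response path (an assumption here).
    asyncio.create_task(recorder.record_context(context))
```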
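The reworked `record_alerts` also illustrates a general TaskGroup fan-out pattern: keep the `Task` handles created inside `async with asyncio.TaskGroup()`, and call `.result()` only after the block exits, when every task is guaranteed finished. A self-contained sketch of that pattern (Python 3.11+; the toy `insert_one` stands in for `_insert_pydantic_model`):

```python
import asyncio
from typing import List

async def insert_one(item: int) -> int:
    # Stand-in for an async DB insert such as _insert_pydantic_model.
    await asyncio.sleep(0)
    return item * 2

async def insert_all(items: List[int]) -> List[int]:
    tasks = []
    async with asyncio.TaskGroup() as tg:  # Python 3.11+
        for item in items:
            tasks.append(tg.create_task(insert_one(item)))
    # The TaskGroup has exited, so every task is done: .result() never blocks.
    return [task.result() for task in tasks]

print(asyncio.run(insert_all([1, 2, 3])))  # [2, 4, 6]
```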