Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.

Commit 7154d67

Browse files
Merge pull request #331 from stacklok/move-db-to-pipeline
Copilot DB integration. Keep DB objects in context to record at the end.
2 parents 35c583f + dfb5b8f commit 7154d67

File tree

13 files changed

+280
-182
lines changed

13 files changed

+280
-182
lines changed

src/codegate/dashboard/post_processing.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -86,26 +86,39 @@ async def parse_output(output_str: str) -> Tuple[Optional[str], Optional[str]]:
8686
logger.warning(f"Error parsing output: {output_str}. {e}")
8787
return None, None
8888

89-
output_message = ""
89+
def _parse_single_output(single_output: dict) -> str:
90+
single_chat_id = single_output.get("id")
91+
single_output_message = ""
92+
for choice in single_output.get("choices", []):
93+
if not isinstance(choice, dict):
94+
continue
95+
content_dict = choice.get("delta", {}) or choice.get("message", {})
96+
single_output_message += content_dict.get("content", "")
97+
return single_output_message, single_chat_id
98+
99+
full_output_message = ""
90100
chat_id = None
91101
if isinstance(output, list):
92102
for output_chunk in output:
93-
if not isinstance(output_chunk, dict):
94-
continue
95-
chat_id = chat_id or output_chunk.get("id")
96-
for choice in output_chunk.get("choices", []):
97-
if not isinstance(choice, dict):
98-
continue
99-
delta_dict = choice.get("delta", {})
100-
output_message += delta_dict.get("content", "")
103+
output_message, output_chat_id = "", None
104+
if isinstance(output_chunk, dict):
105+
output_message, output_chat_id = _parse_single_output(output_chunk)
106+
elif isinstance(output_chunk, str):
107+
try:
108+
output_decoded = json.loads(output_chunk)
109+
output_message, output_chat_id = _parse_single_output(output_decoded)
110+
except Exception:
111+
logger.error(f"Error reading chunk: {output_chunk}")
112+
else:
113+
logger.warning(
114+
f"Could not handle output: {output_chunk}", out_type=type(output_chunk)
115+
)
116+
chat_id = chat_id or output_chat_id
117+
full_output_message += output_message
101118
elif isinstance(output, dict):
102-
chat_id = chat_id or output.get("id")
103-
for choice in output.get("choices", []):
104-
if not isinstance(choice, dict):
105-
continue
106-
output_message += choice.get("message", {}).get("content", "")
119+
full_output_message, chat_id = _parse_single_output(output)
107120

108-
return output_message, chat_id
121+
return full_output_message, chat_id
109122

110123

111124
async def _get_question_answer(
@@ -124,19 +137,23 @@ async def _get_question_answer(
124137
output_msg_str, chat_id = output_task.result()
125138

126139
# If we couldn't parse the request or output, return None
127-
if not request_msg_str or not output_msg_str or not chat_id:
140+
if not request_msg_str:
128141
return None, None
129142

130143
request_message = ChatMessage(
131144
message=request_msg_str,
132145
timestamp=row.timestamp,
133146
message_id=row.id,
134147
)
135-
output_message = ChatMessage(
136-
message=output_msg_str,
137-
timestamp=row.output_timestamp,
138-
message_id=row.output_id,
139-
)
148+
if output_msg_str:
149+
output_message = ChatMessage(
150+
message=output_msg_str,
151+
timestamp=row.output_timestamp,
152+
message_id=row.output_id,
153+
)
154+
else:
155+
output_message = None
156+
chat_id = row.id
140157
return QuestionAnswer(question=request_message, answer=output_message), chat_id
141158

142159

src/codegate/dashboard/request_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class QuestionAnswer(BaseModel):
2222
"""
2323

2424
question: ChatMessage
25-
answer: ChatMessage
25+
answer: Optional[ChatMessage]
2626

2727

2828
class PartialConversation(BaseModel):

src/codegate/db/connection.py

Lines changed: 52 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
import asyncio
2-
import copy
3-
import datetime
42
import json
5-
import uuid
63
from pathlib import Path
7-
from typing import AsyncGenerator, AsyncIterator, List, Optional
4+
from typing import List, Optional
85

96
import structlog
10-
from litellm import ChatCompletionRequest, ModelResponse
117
from pydantic import BaseModel
128
from sqlalchemy import text
139
from sqlalchemy.ext.asyncio import create_async_engine
@@ -18,6 +14,7 @@
1814
GetAlertsWithPromptAndOutputRow,
1915
GetPromptWithOutputsRow,
2016
)
17+
from codegate.pipeline.base import PipelineContext
2118

2219
logger = structlog.get_logger("codegate")
2320
alert_queue = asyncio.Queue()
@@ -103,97 +100,51 @@ async def _insert_pydantic_model(
103100
logger.error(f"Failed to insert model: {model}.", error=str(e))
104101
return None
105102

106-
async def record_request(
107-
self, normalized_request: ChatCompletionRequest, is_fim_request: bool, provider_str: str
108-
) -> Optional[Prompt]:
109-
request_str = None
110-
if isinstance(normalized_request, BaseModel):
111-
request_str = normalized_request.model_dump_json(exclude_none=True, exclude_unset=True)
112-
else:
113-
try:
114-
request_str = json.dumps(normalized_request)
115-
except Exception as e:
116-
logger.error(f"Failed to serialize output: {normalized_request}", error=str(e))
117-
118-
if request_str is None:
119-
logger.warning("No request found to record.")
120-
return
121-
122-
# Create a new prompt record
123-
prompt_params = Prompt(
124-
id=str(uuid.uuid4()), # Generate a new UUID for the prompt
125-
timestamp=datetime.datetime.now(datetime.timezone.utc),
126-
provider=provider_str,
127-
type="fim" if is_fim_request else "chat",
128-
request=request_str,
129-
)
103+
async def record_request(self, prompt_params: Optional[Prompt] = None) -> Optional[Prompt]:
    """Insert a single prompt row into the ``prompts`` table.

    Args:
        prompt_params: Fully-populated Prompt model to persist. When None,
            nothing is written.

    Returns:
        The row produced by ``INSERT ... RETURNING``, or None when there was
        nothing to record (or the insert failed).
    """
    # Guard clause: nothing to persist.
    if prompt_params is None:
        return None
    insert_sql = text(
        """
        INSERT INTO prompts (id, timestamp, provider, request, type)
        VALUES (:id, :timestamp, :provider, :request, :type)
        RETURNING *
        """
    )
    recorded_request = await self._insert_pydantic_model(prompt_params, insert_sql)
    logger.debug(f"Recorded request: {recorded_request}")
    return recorded_request
116+
117+
async def record_outputs(self, outputs: List[Output]) -> Optional[Output]:
    """Store an entire stream of output chunks as one row in ``outputs``.

    The stored row reuses the id/prompt_id/timestamp of the first chunk; the
    ``output`` column holds a JSON-encoded list with every chunk's ``output``
    payload (chunk metadata beyond the first chunk's is intentionally dropped).

    Args:
        outputs: Output chunks collected for a single model response.

    Returns:
        The inserted row, or None when ``outputs`` is empty (or the insert
        failed).
    """
    if not outputs:
        return None

    first_output = outputs[0]
    # Create a single entry on DB but encode all of the chunks in the stream
    # as a list of JSON objects in the field `output`. Only the model
    # responses are stored in that list.
    output_db = Output(
        id=first_output.id,
        prompt_id=first_output.prompt_id,
        timestamp=first_output.timestamp,
        output=json.dumps([output.output for output in outputs]),
    )

    sql = text(
        """
        INSERT INTO outputs (id, prompt_id, timestamp, output)
        VALUES (:id, :prompt_id, :timestamp, :output)
        RETURNING *
        """
    )
    recorded_output = await self._insert_pydantic_model(output_db, sql)
    logger.debug(f"Recorded output: {recorded_output}")
    return recorded_output
180146

181-
output_str = None
182-
if isinstance(model_response, BaseModel):
183-
output_str = model_response.model_dump_json(exclude_none=True, exclude_unset=True)
184-
else:
185-
try:
186-
output_str = json.dumps(model_response)
187-
except Exception as e:
188-
logger.error(f"Failed to serialize output: {model_response}", error=str(e))
189-
190-
if output_str is None:
191-
logger.warning("No output found to record.")
192-
return
193-
194-
return await self._record_output(prompt, output_str)
195-
196-
async def record_alerts(self, alerts: List[Alert]) -> None:
147+
# Inserts each alert concurrently (TaskGroup), then collects the results and
# queues a notification per critical alert. The SQL column list is outside the
# visible hunk (original lines 200-207), so the code below is left untouched.
async def record_alerts(self, alerts: List[Alert]) -> List[Alert]:
197148
if not alerts:
198149
# NOTE(review): returns None on empty input although the annotation
# promises List[Alert]; callers should not rely on a list here.
return
199150
sql = text(
@@ -208,15 +159,33 @@ async def record_alerts(self, alerts: List[Alert]) -> None:
208159
"""
209160
)
210161
# We can insert each alert independently in parallel.
162+
alerts_tasks = []
211163
async with asyncio.TaskGroup() as tg:
212164
for alert in alerts:
213165
try:
214166
result = tg.create_task(self._insert_pydantic_model(alert, sql))
215-
if result and alert.trigger_category == "critical":
216-
await alert_queue.put(f"New alert detected: {alert.timestamp}")
167+
alerts_tasks.append(result)
217168
except Exception as e:
218169
logger.error(f"Failed to record alert: {alert}.", error=str(e))
219-
return None
170+
171+
# Gather results after the TaskGroup has finished all inserts.
recorded_alerts = []
172+
for alert_coro in alerts_tasks:
173+
alert_result = alert_coro.result()
174+
recorded_alerts.append(alert_result)
175+
if alert_result and alert_result.trigger_category == "critical":
176+
# NOTE(review): `alert` here is the stale loop variable from the
# TaskGroup loop above, so this always reports the *last* alert's
# timestamp — presumably `alert_result.timestamp` was intended;
# confirm and fix.
await alert_queue.put(f"New alert detected: {alert.timestamp}")
177+
178+
logger.debug(f"Recorded alerts: {recorded_alerts}")
179+
return recorded_alerts
180+
181+
async def record_context(self, context: PipelineContext) -> None:
    """Flush everything gathered in a pipeline context to the database.

    Persists, in order: the input request, the collected output chunks, and
    the alerts raised while the pipeline ran.
    """
    num_outputs = len(context.output_responses)
    num_alerts = len(context.alerts_raised)
    logger.info(
        f"Recording context in DB. Output chunks: {num_outputs}. "
        f"Alerts: {num_alerts}."
    )
    await self.record_request(context.input_request)
    await self.record_outputs(context.output_responses)
    await self.record_alerts(context.alerts_raised)
220189

221190

222191
class DbReader(DbCodeGate):

0 commit comments

Comments
 (0)