11import asyncio
22import json
33import re
4- from typing import List , Optional , Tuple , Union
4+ from collections import defaultdict
5+ from typing import List , Optional , Union
56
67import structlog
78
89from codegate .dashboard .request_models import (
910 AlertConversation ,
1011 ChatMessage ,
1112 Conversation ,
12- PartialConversation ,
13+ PartialQuestionAnswer ,
14+ PartialQuestions ,
1315 QuestionAnswer ,
1416)
1517from codegate .db .models import GetAlertsWithPromptAndOutputRow , GetPromptWithOutputsRow
@@ -74,60 +76,57 @@ async def parse_request(request_str: str) -> Optional[str]:
7476 return None
7577
7678 # Only respond with the latest message
77- return messages [ - 1 ]
79+ return messages
7880
7981
80- async def parse_output (output_str : str ) -> Tuple [ Optional [str ], Optional [ str ] ]:
82+ async def parse_output (output_str : str ) -> Optional [str ]:
8183 """
82- Parse the output string from the pipeline and return the message and chat_id .
84+ Parse the output string from the pipeline and return the message.
8385 """
8486 try :
8587 if output_str is None :
86- return None , None
88+ return None
8789
8890 output = json .loads (output_str )
8991 except Exception as e :
9092 logger .warning (f"Error parsing output: { output_str } . { e } " )
91- return None , None
93+ return None
9294
9395 def _parse_single_output (single_output : dict ) -> str :
94- single_chat_id = single_output .get ("id" )
9596 single_output_message = ""
9697 for choice in single_output .get ("choices" , []):
9798 if not isinstance (choice , dict ):
9899 continue
99100 content_dict = choice .get ("delta" , {}) or choice .get ("message" , {})
100101 single_output_message += content_dict .get ("content" , "" )
101- return single_output_message , single_chat_id
102+ return single_output_message
102103
103104 full_output_message = ""
104- chat_id = None
105105 if isinstance (output , list ):
106106 for output_chunk in output :
107- output_message , output_chat_id = "" , None
107+ output_message = ""
108108 if isinstance (output_chunk , dict ):
109- output_message , output_chat_id = _parse_single_output (output_chunk )
109+ output_message = _parse_single_output (output_chunk )
110110 elif isinstance (output_chunk , str ):
111111 try :
112112 output_decoded = json .loads (output_chunk )
113- output_message , output_chat_id = _parse_single_output (output_decoded )
113+ output_message = _parse_single_output (output_decoded )
114114 except Exception :
115115 logger .error (f"Error reading chunk: { output_chunk } " )
116116 else :
117117 logger .warning (
118118 f"Could not handle output: { output_chunk } " , out_type = type (output_chunk )
119119 )
120- chat_id = chat_id or output_chat_id
121120 full_output_message += output_message
122121 elif isinstance (output , dict ):
123- full_output_message , chat_id = _parse_single_output (output )
122+ full_output_message = _parse_single_output (output )
124123
125- return full_output_message , chat_id
124+ return full_output_message
126125
127126
128127async def _get_question_answer (
129128 row : Union [GetPromptWithOutputsRow , GetAlertsWithPromptAndOutputRow ]
130- ) -> Tuple [ Optional [QuestionAnswer ], Optional [ str ] ]:
129+ ) -> Optional [PartialQuestionAnswer ]:
131130 """
132131 Parse a row from the get_prompt_with_outputs query and return a PartialConversation
133132
@@ -137,17 +136,19 @@ async def _get_question_answer(
137136 request_task = tg .create_task (parse_request (row .request ))
138137 output_task = tg .create_task (parse_output (row .output ))
139138
140- request_msg_str = request_task .result ()
141- output_msg_str , chat_id = output_task .result ()
139+ request_user_msgs = request_task .result ()
140+ output_msg_str = output_task .result ()
142141
143- # If we couldn't parse the request or output , return None
144- if not request_msg_str :
145- return None , None
142+ # If we couldn't parse the request, return None
143+ if not request_user_msgs :
144+ return None
146145
147- request_message = ChatMessage (
148- message = request_msg_str ,
146+ request_message = PartialQuestions (
147+ messages = request_user_msgs ,
149148 timestamp = row .timestamp ,
150149 message_id = row .id ,
150+ provider = row .provider ,
151+ type = row .type ,
151152 )
152153 if output_msg_str :
153154 output_message = ChatMessage (
@@ -157,28 +158,7 @@ async def _get_question_answer(
157158 )
158159 else :
159160 output_message = None
160- chat_id = row .id
161- return QuestionAnswer (question = request_message , answer = output_message ), chat_id
162-
163-
164- async def parse_get_prompt_with_output (
165- row : GetPromptWithOutputsRow ,
166- ) -> Optional [PartialConversation ]:
167- """
168- Parse a row from the get_prompt_with_outputs query and return a PartialConversation
169-
170- The row contains the raw request and output strings from the pipeline.
171- """
172- question_answer , chat_id = await _get_question_answer (row )
173- if not question_answer or not chat_id :
174- return None
175- return PartialConversation (
176- question_answer = question_answer ,
177- provider = row .provider ,
178- type = row .type ,
179- chat_id = chat_id ,
180- request_timestamp = row .timestamp ,
181- )
161+ return PartialQuestionAnswer (partial_questions = request_message , answer = output_message )
182162
183163
184164def parse_question_answer (input_text : str ) -> str :
@@ -195,50 +175,135 @@ def parse_question_answer(input_text: str) -> str:
195175 return input_text
196176
197177
178+ def _group_partial_messages (pq_list : List [PartialQuestions ]) -> List [List [PartialQuestions ]]:
179+ """
180+ A PartialQuestion is an object that contains several user messages provided from a
181+ chat conversation. Example:
182+ - PartialQuestion(messages=["Hello"], timestamp=2022-01-01T00:00:00Z)
183+ - PartialQuestion(messages=["Hello", "How are you?"], timestamp=2022-01-01T00:00:01Z)
184+ In the above example both PartialQuestions are part of the same conversation and should be
185+ matched together.
186+ Group PartialQuestions objects such that:
187+ - If one PartialQuestion (pq) is a subset of another pq's messages, group them together.
188+ - If multiple subsets exist for the same superset, choose only the one
189+ closest in timestamp to the superset.
190+ - Leave any unpaired pq by itself.
191+ - Finally, sort the resulting groups by the earliest timestamp in each group.
192+ """
193+ # 1) Sort by length of messages descending (largest/most-complete first),
194+ # then by timestamp ascending for stable processing.
195+ pq_list_sorted = sorted (pq_list , key = lambda x : (- len (x .messages ), x .timestamp ))
196+
197+ used = set ()
198+ groups = []
199+
200+ # 2) Iterate in order of "largest messages first"
201+ for sup in pq_list_sorted :
202+ if sup .message_id in used :
203+ continue # Already grouped
204+
205+ # Find all potential subsets of 'sup' that are not yet used
206+ # (If sup's messages == sub's messages, that also counts, because sub ⊆ sup)
207+ possible_subsets = []
208+ for sub in pq_list_sorted :
209+ if sub .message_id == sup .message_id :
210+ continue
211+ if sub .message_id in used :
212+ continue
213+ if (
214+ set (sub .messages ).issubset (set (sup .messages ))
215+ and sub .provider == sup .provider
216+ and set (sub .messages ) != set (sup .messages )
217+ ):
218+ possible_subsets .append (sub )
219+
220+ # 3) If there are no subsets, this sup stands alone
221+ if not possible_subsets :
222+ groups .append ([sup ])
223+ used .add (sup .message_id )
224+ else :
225+ # 4) Group subsets by messages to discard duplicates e.g.: 2 subsets with single 'hello'
226+ subs_group_by_messages = defaultdict (list )
227+ for q in possible_subsets :
228+ subs_group_by_messages [tuple (q .messages )].append (q )
229+
230+ new_group = [sup ]
231+ used .add (sup .message_id )
232+ for subs_same_message in subs_group_by_messages .values ():
233+ # If more than one pick the one subset closest in time to sup
234+ closest_subset = min (
235+ subs_same_message , key = lambda s : abs (s .timestamp - sup .timestamp )
236+ )
237+ new_group .append (closest_subset )
238+ used .add (closest_subset .message_id )
239+ groups .append (new_group )
240+
241+ # 5) Sort the groups by the earliest timestamp within each group
242+ groups .sort (key = lambda g : min (pq .timestamp for pq in g ))
243+ return groups
244+
245+
246+ def _get_question_answer_from_partial (
247+ partial_question_answer : PartialQuestionAnswer ,
248+ ) -> QuestionAnswer :
249+ """
250+ Get a QuestionAnswer object from a PartialQuestionAnswer object.
251+ """
252+ # Get the last user message as the question
253+ question = ChatMessage (
254+ message = partial_question_answer .partial_questions .messages [- 1 ],
255+ timestamp = partial_question_answer .partial_questions .timestamp ,
256+ message_id = partial_question_answer .partial_questions .message_id ,
257+ )
258+
259+ return QuestionAnswer (question = question , answer = partial_question_answer .answer )
260+
261+
198262async def match_conversations (
199- partial_conversations : List [Optional [PartialConversation ]],
263+ partial_question_answers : List [Optional [PartialQuestionAnswer ]],
200264) -> List [Conversation ]:
201265 """
202266 Match partial conversations to form a complete conversation.
203267 """
204- convers = {}
205- for partial_conversation in partial_conversations :
206- if not partial_conversation :
207- continue
208-
209- # Group by chat_id
210- if partial_conversation .chat_id not in convers :
211- convers [partial_conversation .chat_id ] = []
212- convers [partial_conversation .chat_id ].append (partial_conversation )
268+ valid_partial_qas = [
269+ partial_qas for partial_qas in partial_question_answers if partial_qas is not None
270+ ]
271+ grouped_partial_questions = _group_partial_messages (
272+ [partial_qs_a .partial_questions for partial_qs_a in valid_partial_qas ]
273+ )
213274
214- # Sort by timestamp
215- sorted_convers = {
216- chat_id : sorted (conversations , key = lambda x : x .request_timestamp )
217- for chat_id , conversations in convers .items ()
218- }
219275 # Create the conversation objects
220276 conversations = []
221- for chat_id , sorted_convers in sorted_convers . items () :
277+ for group in grouped_partial_questions :
222278 questions_answers = []
223- first_partial_conversation = None
224- for partial_conversation in sorted_convers :
279+ first_partial_qa = None
280+ for partial_question in sorted (group , key = lambda x : x .timestamp ):
281+ # Partial questions don't contain the answer, so we need to find the corresponding
282+ selected_partial_qa = None
283+ for partial_qa in valid_partial_qas :
284+ if partial_question .message_id == partial_qa .partial_questions .message_id :
285+ selected_partial_qa = partial_qa
286+ break
287+
225288 # check if we have an answer, otherwise do not add it
226- if partial_conversation .question_answer .answer is not None :
227- first_partial_conversation = partial_conversation
228- partial_conversation .question_answer .question .message = parse_question_answer (
229- partial_conversation .question_answer .question .message
289+ if selected_partial_qa .answer is not None :
290+ # if we don't have a first question, set it
291+ first_partial_qa = first_partial_qa or selected_partial_qa
292+ question_answer = _get_question_answer_from_partial (selected_partial_qa )
293+ question_answer .question .message = parse_question_answer (
294+ question_answer .question .message
230295 )
231- questions_answers .append (partial_conversation . question_answer )
296+ questions_answers .append (question_answer )
232297
233298 # only add conversation if we have some answers
234- if len (questions_answers ) > 0 and first_partial_conversation is not None :
299+ if len (questions_answers ) > 0 and first_partial_qa is not None :
235300 conversations .append (
236301 Conversation (
237302 question_answers = questions_answers ,
238- provider = first_partial_conversation .provider ,
239- type = first_partial_conversation .type ,
240- chat_id = chat_id ,
241- conversation_timestamp = sorted_convers [ 0 ]. request_timestamp ,
303+ provider = first_partial_qa . partial_questions .provider ,
304+ type = first_partial_qa . partial_questions .type ,
305+ chat_id = first_partial_qa . partial_questions . message_id ,
306+ conversation_timestamp = first_partial_qa . partial_questions . timestamp ,
242307 )
243308 )
244309
@@ -254,10 +319,10 @@ async def parse_messages_in_conversations(
254319
255320 # Parse the prompts and outputs in parallel
256321 async with asyncio .TaskGroup () as tg :
257- tasks = [tg .create_task (parse_get_prompt_with_output (row )) for row in prompts_outputs ]
258- partial_conversations = [task .result () for task in tasks ]
322+ tasks = [tg .create_task (_get_question_answer (row )) for row in prompts_outputs ]
323+ partial_question_answers = [task .result () for task in tasks ]
259324
260- conversations = await match_conversations (partial_conversations )
325+ conversations = await match_conversations (partial_question_answers )
261326 return conversations
262327
263328
@@ -269,15 +334,17 @@ async def parse_row_alert_conversation(
269334
270335 The row contains the raw request and output strings from the pipeline.
271336 """
272- question_answer , chat_id = await _get_question_answer (row )
273- if not question_answer or not chat_id :
337+ partial_qa = await _get_question_answer (row )
338+ if not partial_qa :
274339 return None
275340
341+ question_answer = _get_question_answer_from_partial (partial_qa )
342+
276343 conversation = Conversation (
277344 question_answers = [question_answer ],
278345 provider = row .provider ,
279346 type = row .type ,
280- chat_id = chat_id or "chat-id-not-found" ,
347+ chat_id = row . id ,
281348 conversation_timestamp = row .timestamp ,
282349 )
283350 code_snippet = json .loads (row .code_snippet ) if row .code_snippet else None
0 commit comments