@@ -174,6 +174,87 @@ def __init__(self, issue_url: str, ai_handler, args: list = None):
174174 else :
175175 get_logger ().info ('No new issues to update' )
176176
177+ elif get_settings ().pr_similar_issue .vectordb == "qdrant" :
178+ try :
179+ import qdrant_client
180+ from qdrant_client .models import (Distance , FieldCondition ,
181+ Filter , MatchValue ,
182+ PointStruct , VectorParams )
183+ except Exception :
184+ raise Exception ("Please install qdrant-client to use qdrant as vectordb" )
185+
186+ api_key = None
187+ url = None
188+ try :
189+ api_key = get_settings ().qdrant .api_key
190+ url = get_settings ().qdrant .url
191+ except Exception :
192+ if not self .cli_mode :
193+ repo_name , original_issue_number = self .git_provider ._parse_issue_url (self .issue_url .split ('=' )[- 1 ])
194+ issue_main = self .git_provider .repo_obj .get_issue (original_issue_number )
195+ issue_main .create_comment ("Please set qdrant url and api key in secrets file" )
196+ raise Exception ("Please set qdrant url and api key in secrets file" )
197+
198+ self .qdrant = qdrant_client .QdrantClient (url = url , api_key = api_key )
199+
200+ run_from_scratch = False
201+ ingest = True
202+
203+ if not self .qdrant .collection_exists (collection_name = self .index_name ):
204+ run_from_scratch = True
205+ ingest = False
206+ self .qdrant .create_collection (
207+ collection_name = self .index_name ,
208+ vectors_config = VectorParams (size = 1536 , distance = Distance .COSINE ),
209+ )
210+ else :
211+ if get_settings ().pr_similar_issue .force_update_dataset :
212+ ingest = True
213+ else :
214+ response = self .qdrant .count (
215+ collection_name = self .index_name ,
216+ count_filter = Filter (must = [
217+ FieldCondition (key = "metadata.repo" , match = MatchValue (value = repo_name_for_index )),
218+ FieldCondition (key = "id" , match = MatchValue (value = f"example_issue_{ repo_name_for_index } " )),
219+ ]),
220+ )
221+ ingest = True if response .count == 0 else False
222+
223+ if run_from_scratch or ingest :
224+ get_logger ().info ('Indexing the entire repo...' )
225+ get_logger ().info ('Getting issues...' )
226+ issues = list (repo_obj .get_issues (state = 'all' ))
227+ get_logger ().info ('Done' )
228+ self ._update_qdrant_with_issues (issues , repo_name_for_index , ingest = ingest )
229+ else :
230+ issues_to_update = []
231+ issues_paginated_list = repo_obj .get_issues (state = 'all' )
232+ counter = 1
233+ for issue in issues_paginated_list :
234+ if issue .pull_request :
235+ continue
236+ issue_str , comments , number = self ._process_issue (issue )
237+ issue_key = f"issue_{ number } "
238+ point_id = issue_key + "." + "issue"
239+ response = self .qdrant .count (
240+ collection_name = self .index_name ,
241+ count_filter = Filter (must = [
242+ FieldCondition (key = "id" , match = MatchValue (value = point_id )),
243+ FieldCondition (key = "metadata.repo" , match = MatchValue (value = repo_name_for_index )),
244+ ]),
245+ )
246+ if response .count == 0 :
247+ counter += 1
248+ issues_to_update .append (issue )
249+ else :
250+ break
251+
252+ if issues_to_update :
253+ get_logger ().info (f'Updating index with { counter } new issues...' )
254+ self ._update_qdrant_with_issues (issues_to_update , repo_name_for_index , ingest = True )
255+ else :
256+ get_logger ().info ('No new issues to update' )
257+
177258
178259 async def run (self ):
179260 get_logger ().info ('Getting issue...' )
@@ -246,6 +327,36 @@ async def run(self):
246327 score_list .append (str ("{:.2f}" .format (1 - r ['_distance' ])))
247328 get_logger ().info ('Done' )
248329
330+ elif get_settings ().pr_similar_issue .vectordb == "qdrant" :
331+ from qdrant_client .models import FieldCondition , Filter , MatchValue
332+ res = self .qdrant .search (
333+ collection_name = self .index_name ,
334+ query_vector = embeds [0 ],
335+ limit = 5 ,
336+ query_filter = Filter (must = [FieldCondition (key = "metadata.repo" , match = MatchValue (value = self .repo_name_for_index ))]),
337+ with_payload = True ,
338+ )
339+
340+ for r in res :
341+ rid = r .payload .get ("id" , "" )
342+ if 'example_issue_' in rid :
343+ continue
344+ try :
345+ issue_number = int (rid .split ('.' )[0 ].split ('_' )[- 1 ])
346+ except Exception :
347+ get_logger ().debug (f"Failed to parse issue number from { rid } " )
348+ continue
349+ if original_issue_number == issue_number :
350+ continue
351+ if issue_number not in relevant_issues_number_list :
352+ relevant_issues_number_list .append (issue_number )
353+ if 'comment' in rid :
354+ relevant_comment_number_list .append (int (rid .split ('.' )[1 ].split ('_' )[- 1 ]))
355+ else :
356+ relevant_comment_number_list .append (- 1 )
357+ score_list .append (str ("{:.2f}" .format (r .score )))
358+ get_logger ().info ('Done' )
359+
249360 get_logger ().info ('Publishing response...' )
250361 similar_issues_str = "### Similar Issues\n ___\n \n "
251362
@@ -458,6 +569,101 @@ def _update_table_with_issues(self, issues_list, repo_name_for_index, ingest=Fal
458569 get_logger ().info ('Done' )
459570
460571
572+ def _update_qdrant_with_issues (self , issues_list , repo_name_for_index , ingest = False ):
573+ try :
574+ import uuid
575+
576+ import pandas as pd
577+ from qdrant_client .models import PointStruct
578+ except Exception :
579+ raise
580+
581+ get_logger ().info ('Processing issues...' )
582+ corpus = Corpus ()
583+ example_issue_record = Record (
584+ id = f"example_issue_{ repo_name_for_index } " ,
585+ text = "example_issue" ,
586+ metadata = Metadata (repo = repo_name_for_index )
587+ )
588+ corpus .append (example_issue_record )
589+
590+ counter = 0
591+ for issue in issues_list :
592+ if issue .pull_request :
593+ continue
594+
595+ counter += 1
596+ if counter % 100 == 0 :
597+ get_logger ().info (f"Scanned { counter } issues" )
598+ if counter >= self .max_issues_to_scan :
599+ get_logger ().info (f"Scanned { self .max_issues_to_scan } issues, stopping" )
600+ break
601+
602+ issue_str , comments , number = self ._process_issue (issue )
603+ issue_key = f"issue_{ number } "
604+ username = issue .user .login
605+ created_at = str (issue .created_at )
606+ if len (issue_str ) < 8000 or \
607+ self .token_handler .count_tokens (issue_str ) < get_max_tokens (MODEL ):
608+ issue_record = Record (
609+ id = issue_key + "." + "issue" ,
610+ text = issue_str ,
611+ metadata = Metadata (repo = repo_name_for_index ,
612+ username = username ,
613+ created_at = created_at ,
614+ level = IssueLevel .ISSUE )
615+ )
616+ corpus .append (issue_record )
617+ if comments :
618+ for j , comment in enumerate (comments ):
619+ comment_body = comment .body
620+ num_words_comment = len (comment_body .split ())
621+ if num_words_comment < 10 or not isinstance (comment_body , str ):
622+ continue
623+
624+ if len (comment_body ) < 8000 or \
625+ self .token_handler .count_tokens (comment_body ) < MAX_TOKENS [MODEL ]:
626+ comment_record = Record (
627+ id = issue_key + ".comment_" + str (j + 1 ),
628+ text = comment_body ,
629+ metadata = Metadata (repo = repo_name_for_index ,
630+ username = username ,
631+ created_at = created_at ,
632+ level = IssueLevel .COMMENT )
633+ )
634+ corpus .append (comment_record )
635+
636+ df = pd .DataFrame (corpus .dict ()["documents" ])
637+ get_logger ().info ('Done' )
638+
639+ get_logger ().info ('Embedding...' )
640+ openai .api_key = get_settings ().openai .key
641+ list_to_encode = list (df ["text" ].values )
642+ try :
643+ res = openai .Embedding .create (input = list_to_encode , engine = MODEL )
644+ embeds = [record ['embedding' ] for record in res ['data' ]]
645+ except Exception :
646+ embeds = []
647+ get_logger ().error ('Failed to embed entire list, embedding one by one...' )
648+ for i , text in enumerate (list_to_encode ):
649+ try :
650+ res = openai .Embedding .create (input = [text ], engine = MODEL )
651+ embeds .append (res ['data' ][0 ]['embedding' ])
652+ except Exception :
653+ embeds .append ([0 ] * 1536 )
654+ df ["vector" ] = embeds
655+ get_logger ().info ('Done' )
656+
657+ get_logger ().info ('Upserting into Qdrant...' )
658+ points = []
659+ for row in df .to_dict (orient = "records" ):
660+ points .append (
661+ PointStruct (id = uuid .uuid5 (uuid .NAMESPACE_DNS , row ["id" ]).hex , vector = row ["vector" ], payload = {"id" : row ["id" ], "text" : row ["text" ], "metadata" : row ["metadata" ]})
662+ )
663+ self .qdrant .upsert (collection_name = self .index_name , points = points )
664+ get_logger ().info ('Done' )
665+
666+
class IssueLevel(str, Enum):
    """Granularity tag for an indexed record: a whole issue or a single comment.

    Inherits from str so members compare equal to (and serialize as) their
    plain string values when stored in vector-DB metadata.
    """

    ISSUE = "issue"
    COMMENT = "comment"
0 commit comments