1717from memu .prompts .memory_type import PROMPTS as MEMORY_TYPE_PROMPTS
1818from memu .prompts .preprocess import PROMPTS as PREPROCESS_PROMPTS
1919from memu .prompts .retrieve .judger import PROMPT as RETRIEVE_JUDGER_PROMPT
20+ from memu .prompts .retrieve .query_rewriter import PROMPT as QUERY_REWRITER_PROMPT
2021from memu .storage .local_fs import LocalFS
2122from memu .utils .video import VideoFrameExtractor
2223from memu .vector .index import cosine_topk
@@ -777,9 +778,56 @@ def _validate_config(
777778 return model_type ()
778779 return model_type .model_validate (config )
779780
780- async def retrieve (self , query : str , * , top_k : int = 5 ) -> dict [str , Any ]:
781+ async def retrieve (
782+ self ,
783+ query : str ,
784+ * ,
785+ conversation_history : list [dict [str , str ]] | None = None ,
786+ method : str = "rag" ,
787+ top_k : int = 5 ,
788+ ) -> dict [str , Any ]:
789+ """
790+ Retrieve relevant memories based on the query.
791+
792+ Args:
793+ query: The search query
794+ conversation_history: Optional conversation history for query rewriting
795+ method: Retrieval method - "rag" (vector similarity) or "llm" (LLM-based ranking)
796+ top_k: Number of top results to return
797+
798+ Returns:
799+ Dictionary containing original_query, rewritten_query, method, and retrieved results
800+ """
801+ # Rewrite query if conversation history is provided
802+ original_query = query
803+ rewritten_query = query
804+
805+ if conversation_history :
806+ rewritten_query = await self ._rewrite_query_with_history (query , conversation_history )
807+ logger .debug (f"Original query: { original_query } " )
808+ logger .debug (f"Rewritten query: { rewritten_query } " )
809+
810+ response : dict [str , Any ] = {
811+ "original_query" : original_query ,
812+ "rewritten_query" : rewritten_query ,
813+ "method" : method ,
814+ "resources" : [],
815+ "items" : [],
816+ "categories" : [],
817+ }
818+
819+ if method == "rag" :
820+ return await self ._retrieve_rag (rewritten_query , response , top_k )
821+ elif method == "llm" :
822+ return await self ._retrieve_llm (rewritten_query , response , top_k )
823+ else :
824+ msg = f"Unknown retrieval method '{ method } '. Use 'rag' or 'llm'."
825+ raise ValueError (msg )
826+
827+ async def _retrieve_rag (self , query : str , response : dict [str , Any ], top_k : int ) -> dict [str , Any ]:
828+ """RAG-based retrieval using vector similarity search"""
829+ # Use query for embedding
781830 qvec = (await self .openai .embed ([query ]))[0 ]
782- response : dict [str , list [dict [str , Any ]]] = {"resources" : [], "items" : [], "categories" : []}
783831 content_sections : list [str ] = []
784832
785833 cat_hits , summary_lookup = await self ._rank_categories_by_summary (qvec , top_k )
@@ -806,6 +854,126 @@ async def retrieve(self, query: str, *, top_k: int = 5) -> dict[str, Any]:
806854
807855 return response
808856
async def _retrieve_llm(self, query: str, response: dict[str, Any], top_k: int) -> dict[str, Any]:
    """Fill ``response`` with memories chosen by the language model.

    Every category, item and resource held in ``self.store`` is snapshotted
    up front; each non-empty group is then ranked, one after another, by
    ``_llm_rank_memories`` and written to the matching ``response`` key.
    Keys whose group is empty are left as they were passed in.
    """
    # Snapshot all three collections before the first await.
    pools = {
        kind: list(getattr(self.store, kind).values())
        for kind in ("categories", "items", "resources")
    }

    for kind, pool in pools.items():
        if not pool:
            continue
        response[kind] = await self._llm_rank_memories(query, pool, kind, top_k)

    return response
878+
async def _llm_rank_memories(
    self,
    query: str,
    memories: list[Any],
    memory_type: str,
    top_k: int,
    *,
    max_candidates: int = 20,
) -> list[dict[str, Any]]:
    """Use the LLM to rank and select the memories most relevant to ``query``.

    Args:
        query: The (possibly rewritten) search query.
        memories: Candidate memory objects of a single kind.
        memory_type: "categories", "items", or anything else, which is
            treated as resources. Selects which attributes are read.
        top_k: Maximum number of memories to return.
        max_candidates: Only the first ``max_candidates`` memories are shown
            to the LLM, to keep the prompt within token limits. Memories
            beyond that prefix are never considered.

    Returns:
        A list of at most ``top_k`` dicts in the order chosen by the LLM.
        Each has ``id`` and ``score`` plus type-specific fields. ``score`` is
        a synthetic, rank-derived value (1.0, 0.9, ...) that never drops
        below 0.0.
    """
    # Nothing to rank, or nothing requested: skip the LLM call entirely.
    # (A negative top_k would otherwise mis-slice the selection below.)
    if not memories or top_k <= 0:
        return []

    # Limit the candidate set to avoid token limits.
    candidates = memories[:max_candidates]
    if not candidates:
        return []

    def describe(mem: Any) -> str:
        # One-line description of a memory as shown to the LLM.
        if memory_type == "categories":
            return f"Category: {mem.name}\nSummary: {mem.summary or 'N/A'}"
        if memory_type == "items":
            return f"Item: {mem.summary}"
        return f"Resource: {mem.caption or mem.url}"

    memories_text = "\n\n".join(f"[{idx}] {describe(mem)}" for idx, mem in enumerate(candidates))

    # Create prompt for LLM ranking
    prompt = f"""Given the query and a list of memories, select the top {top_k} most relevant memories.
Return only the indices (numbers) of the selected memories, separated by commas.

Query: {query}

Memories:
{memories_text}

Output format: 0,3,7,... (indices only, comma-separated)
Selected indices:"""

    response_text = await self.openai.summarize(prompt, system_prompt=None)

    # Indices come back validated against the candidate count and de-duplicated.
    selected_indices = self._parse_llm_indices(response_text, len(candidates))

    result: list[dict[str, Any]] = []
    for rank, idx in enumerate(selected_indices[:top_k]):
        mem = candidates[idx]
        mem_dict: dict[str, Any] = {
            "id": mem.id,
            # Decreasing score by rank; clamped so ranks past the 11th
            # (possible when top_k > 11) do not go negative.
            "score": max(0.0, 1.0 - rank * 0.1),
        }
        if memory_type == "categories":
            mem_dict.update({"name": mem.name, "summary": mem.summary})
        elif memory_type == "items":
            mem_dict.update({"summary": mem.summary, "memory_type": mem.memory_type})
        else:
            mem_dict.update({"url": mem.url, "caption": mem.caption})
        result.append(mem_dict)

    return result
937+
def _parse_llm_indices(self, response: str, max_idx: int) -> list[int]:
    """Parse memory indices from an LLM response.

    Every run of digits in ``response`` is read as a candidate index.
    Candidates outside ``[0, max_idx)`` and repeats are dropped; the order of
    first appearance is preserved.

    Args:
        response: Raw text returned by the LLM. An empty or ``None`` response
            yields an empty list.
        max_idx: Exclusive upper bound for valid indices.

    Returns:
        The valid, de-duplicated indices in the order they appear.
    """
    # An empty/None reply or an empty candidate range has nothing to select.
    if not response or max_idx <= 0:
        return []

    seen: set[int] = set()
    indices: list[int] = []
    for num_str in re.findall(r"\d+", response):
        idx = int(num_str)
        # \d+ can never produce a negative value, so only the upper bound matters.
        if idx < max_idx and idx not in seen:
            seen.add(idx)
            indices.append(idx)
    return indices
948+
async def _rewrite_query_with_history(self, query: str, conversation_history: list[dict[str, str]]) -> str:
    """Rewrite ``query`` with the help of the conversation history.

    The history is rendered as one ``role: content`` line per message,
    inserted together with the query into ``QUERY_REWRITER_PROMPT`` (both
    passed through ``_escape_prompt_value``), and sent to the LLM. The
    original query is returned whenever no non-empty rewrite can be parsed
    from the reply.
    """
    # Render the transcript; missing keys fall back to 'unknown' / ''.
    transcript_lines = []
    for message in conversation_history:
        speaker = message.get("role", "unknown")
        text = message.get("content", "")
        transcript_lines.append(f"{speaker}: {text}")
    transcript = "\n".join(transcript_lines)

    # Build the rewriting prompt from the escaped transcript and query.
    escaped_history = self._escape_prompt_value(transcript)
    escaped_query = self._escape_prompt_value(query)
    prompt = QUERY_REWRITER_PROMPT.format(conversation_history=escaped_history, query=escaped_query)

    # Ask the LLM and extract the rewritten query from its reply.
    llm_reply = await self.openai.summarize(prompt, system_prompt=None)
    parsed = self._parse_rewritten_query(llm_reply)
    if parsed:
        return parsed
    return query  # Fall back to the original query
967+
968+ def _parse_rewritten_query (self , response : str ) -> str | None :
969+ """Parse rewritten query from LLM response"""
970+ # Try to extract content between <rewritten_query> tags
971+ match = re .search (r"<rewritten_query>\s*(.*?)\s*</rewritten_query>" , response , re .DOTALL )
972+ if match :
973+ return match .group (1 ).strip ()
974+ # If no tags found, return the response as is (fallback)
975+ return response .strip ()
976+
809977 async def _rank_categories_by_summary (
810978 self , query_vec : list [float ], top_k : int
811979 ) -> tuple [list [tuple [str , float ]], dict [str , str ]]:
0 commit comments