@@ -785,13 +785,14 @@ def _validate_config(
785785
786786 async def retrieve (
787787 self ,
788- queries : list [str ],
788+ queries : list [dict [ str , Any ] ],
789789 ) -> dict [str , Any ]:
790790 """
791791 Retrieve relevant memories based on the query using either RAG-based or LLM-based search.
792792
793793 Args:
794- queries: List of query strings. The last one is the current query, others are context.
794+ queries: List of query messages in format [{"role": "user", "content": {"text": "..."}}].
795+ The last one is the current query, others are context.
795796 If list has only 1 element, no query rewriting is performed.
796797
797798 Returns:
@@ -813,12 +814,13 @@ async def retrieve(
813814 if not queries :
814815 raise ValueError ("empty_queries" )
815816
816- current_query = queries [- 1 ]
817- context_queries = queries [:- 1 ] if len (queries ) > 1 else []
817+ # Extract text from the query structure
818+ current_query = self ._extract_query_text (queries [- 1 ])
819+ context_queries_objs = queries [:- 1 ] if len (queries ) > 1 else []
818820
819821 # Step 1: Decide if retrieval is needed
820822 needs_retrieval , rewritten_query = await self ._decide_if_retrieval_needed (
821- current_query , context_queries , retrieved_content = None
823+ current_query , context_queries_objs , retrieved_content = None
822824 )
823825
824826 # If only one query, do not use the rewritten version (use original)
@@ -842,11 +844,11 @@ async def retrieve(
842844 # Step 2: Perform retrieval with rewritten query using configured method
843845 if self .retrieve_config .method == "llm" :
844846 results = await self ._llm_based_retrieve (
845- rewritten_query , top_k = self .retrieve_config .top_k , context_queries = context_queries
847+ rewritten_query , top_k = self .retrieve_config .top_k , context_queries = context_queries_objs
846848 )
847849 else : # rag
848850 results = await self ._embedding_based_retrieve (
849- rewritten_query , top_k = self .retrieve_config .top_k , context_queries = context_queries
851+ rewritten_query , top_k = self .retrieve_config .top_k , context_queries = context_queries_objs
850852 )
851853
852854 # Add metadata
@@ -874,7 +876,7 @@ async def _rank_categories_by_summary(
874876 async def _decide_if_retrieval_needed (
875877 self ,
876878 query : str ,
877- context_queries : list [str ] | None ,
879+ context_queries : list [dict [ str , Any ] ] | None ,
878880 retrieved_content : str | None = None ,
879881 system_prompt : str | None = None ,
880882 ) -> tuple [bool , str ]:
@@ -883,7 +885,7 @@ async def _decide_if_retrieval_needed(
883885
884886 Args:
885887 query: The current query string
886- context_queries: List of context queries
888+ context_queries: List of previous query objects with role and content
887889 retrieved_content: Content retrieved so far (if checking for sufficiency)
888890 system_prompt: Optional system prompt override
889891
@@ -908,17 +910,61 @@ async def _decide_if_retrieval_needed(
908910
909911 return decision == "RETRIEVE" , rewritten
910912
911- def _format_query_context (self , queries : list [str ] | None ) -> str :
912- """Format query context for prompts"""
913+ def _format_query_context (self , queries : list [dict [ str , Any ] ] | None ) -> str :
914+ """Format query context for prompts, including role information """
913915 if not queries :
914916 return "No query context."
915917
916918 lines = []
917919 for q in queries :
918- lines .append (f"- { q } " )
920+ if isinstance (q , str ):
921+ # Backward compatibility
922+ lines .append (f"- { q } " )
923+ elif isinstance (q , dict ):
924+ role = q .get ("role" , "user" )
925+ content = q .get ("content" )
926+ if isinstance (content , dict ):
927+ text = content .get ("text" , "" )
928+ elif isinstance (content , str ):
929+ text = content
930+ else :
931+ text = str (content )
932+ lines .append (f"- [{ role } ]: { text } " )
933+ else :
934+ lines .append (f"- { q !s} " )
919935
920936 return "\n " .join (lines )
921937
938+ @staticmethod
939+ def _extract_query_text (query : dict [str , Any ]) -> str :
940+ """
941+ Extract text content from query message structure.
942+
943+ Args:
944+ query: Query in format {"role": "user", "content": {"text": "..."}}
945+
946+ Returns:
947+ The extracted text string
948+ """
949+ if isinstance (query , str ):
950+ # Backward compatibility: if it's already a string, return it
951+ return query
952+
953+ if not isinstance (query , dict ):
954+ raise TypeError ("INVALID" )
955+
956+ content = query .get ("content" )
957+ if isinstance (content , dict ):
958+ text = content .get ("text" , "" )
959+ if not text :
960+ raise ValueError ("EMPTY" )
961+ return str (text )
962+ elif isinstance (content , str ):
963+ # Also support {"role": "user", "content": "text"} format
964+ return content
965+ else :
966+ raise TypeError ("INVALID" )
967+
922968 def _extract_decision (self , raw : str ) -> str :
923969 """Extract RETRIEVE or NO_RETRIEVE decision from LLM response"""
924970 if not raw :
@@ -946,7 +992,7 @@ def _extract_rewritten_query(self, raw: str) -> str | None:
946992 return None
947993
948994 async def _embedding_based_retrieve (
949- self , query : str , top_k : int , context_queries : list [str ] | None
995+ self , query : str , top_k : int , context_queries : list [dict [ str , Any ] ] | None
950996 ) -> dict [str , Any ]:
951997 """Embedding-based retrieval with query rewriting and judging at each tier"""
952998 current_query = query
@@ -1056,7 +1102,9 @@ def _extract_judgement(self, raw: str) -> str:
10561102 return "ENOUGH"
10571103 return "MORE"
10581104
1059- async def _llm_based_retrieve (self , query : str , top_k : int , context_queries : list [str ] | None ) -> dict [str , Any ]:
1105+ async def _llm_based_retrieve (
1106+ self , query : str , top_k : int , context_queries : list [dict [str , Any ]] | None
1107+ ) -> dict [str , Any ]:
10601108 """
10611109 LLM-based retrieval that uses language model to search and rank results
10621110 in a hierarchical manner, with query rewriting and judging at each tier.
0 commit comments