3434TConfigModel = TypeVar ("TConfigModel" , bound = BaseModel )
3535
3636
37- class MemoryUser :
37+ class MemoryService :
3838 def __init__ (
3939 self ,
4040 * ,
41- llm_config : dict [str , Any ] | LLMConfig | None = None ,
42- blob_config : dict [str , Any ] | BlobConfig | None = None ,
43- database_config : dict [str , Any ] | DatabaseConfig | None = None ,
44- memorize_config : dict [str , Any ] | MemorizeConfig | None = None ,
41+ llm_config : LLMConfig | dict [str , Any ] | None = None ,
42+ blob_config : BlobConfig | dict [str , Any ] | None = None ,
43+ database_config : DatabaseConfig | dict [str , Any ] | None = None ,
44+ memorize_config : MemorizeConfig | dict [str , Any ] | None = None ,
45+ retrieve_config : RetrieveConfig | dict [str , Any ] | None = None ,
4546 ):
4647 self .llm_config = self ._validate_config (llm_config , LLMConfig )
4748 self .blob_config = self ._validate_config (blob_config , BlobConfig )
4849 self .database_config = self ._validate_config (database_config , DatabaseConfig )
4950 self .memorize_config = self ._validate_config (memorize_config , MemorizeConfig )
51+ self .retrieve_config = self ._validate_config (retrieve_config , RetrieveConfig )
5052 self .fs = LocalFS (self .blob_config .resources_dir )
5153 self .store = InMemoryStore ()
5254 backend = self .llm_config .client_backend
@@ -788,19 +790,13 @@ async def retrieve(
788790 self ,
789791 query : str ,
790792 * ,
791- retrieve_config : dict [str , Any ] | RetrieveConfig | None = None ,
792793 conversation_history : list [dict [str , str ]] | None = None ,
793794 ) -> dict [str , Any ]:
794795 """
795796 Retrieve relevant memories based on the query using either RAG-based or LLM-based search.
796797
797798 Args:
798799 query: The search query string
799- retrieve_config: Configuration for retrieval method and parameters.
800- Can be a dict or RetrieveConfig object with:
801- - method: 'rag' for embedding-based vector search (default),
802- 'llm' for LLM-based semantic ranking
803- - top_k: Maximum number of results per category (default: 5)
804800 conversation_history: Optional list of last 3 conversation turns, each with 'role' and 'content'.
805801 Example: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
806802
@@ -819,14 +815,6 @@ async def retrieve(
819815 - Pre-retrieval decision checks if retrieval is needed based on query type
820816 - Query rewriting incorporates conversation history for better context
821817 """
822- # Validate and resolve config
823- config = self ._validate_config (retrieve_config , RetrieveConfig )
824-
825- # Validate method
826- if config .method not in ("rag" , "llm" ):
827- msg = f"Invalid retrieval method '{ config .method } '. Must be 'rag' or 'llm'."
828- raise ValueError (msg )
829-
830818 # Step 1: Decide if retrieval is needed
831819 needs_retrieval , rewritten_query = await self ._decide_if_retrieval_needed (query , conversation_history )
832820
@@ -844,13 +832,13 @@ async def retrieve(
844832 logger .info (f"Query rewritten: '{ query } ' -> '{ rewritten_query } '" )
845833
846834 # Step 2: Perform retrieval with rewritten query using configured method
847- if config .method == "llm" :
835+ if self . retrieve_config .method == "llm" :
848836 results = await self ._llm_based_retrieve (
849- rewritten_query , top_k = config .top_k , conversation_history = conversation_history
837+ rewritten_query , top_k = self . retrieve_config .top_k , conversation_history = conversation_history
850838 )
851839 else : # rag
852840 results = await self ._embedding_based_retrieve (
853- rewritten_query , top_k = config .top_k , conversation_history = conversation_history
841+ rewritten_query , top_k = self . retrieve_config .top_k , conversation_history = conversation_history
854842 )
855843
856844 # Add metadata
0 commit comments