33import logging
44import pathlib
55import re
6- from collections .abc import Sequence
7- from typing import Any , cast
6+ from collections .abc import Mapping , Sequence
7+ from typing import Any , TypeVar , cast
88
99from pydantic import BaseModel
1010
11- from memu .app .settings import AppSettings
11+ from memu .app .settings import BlobConfig , DatabaseConfig , LLMConfig , MemorizeConfig
1212from memu .llm .http_client import HTTPLLMClient
1313from memu .memory .repo import InMemoryStore
1414from memu .models import CategoryItem , MemoryCategory , MemoryItem , MemoryType , Resource
2424logger = logging .getLogger (__name__ )
2525
2626
27+ TConfigModel = TypeVar ("TConfigModel" , bound = BaseModel )
28+
29+
2730class MemoryUser :
28- def __init__ (self , settings : AppSettings ):
29- self .settings = settings
30- self .fs = LocalFS (settings .resources_dir )
31+ def __init__ (
32+ self ,
33+ * ,
34+ llm_config : dict [str , Any ] | LLMConfig | None = None ,
35+ blob_config : dict [str , Any ] | BlobConfig | None = None ,
36+ database_config : dict [str , Any ] | DatabaseConfig | None = None ,
37+ memorize_config : dict [str , Any ] | MemorizeConfig | None = None ,
38+ ):
39+ self .llm_config = self ._validate_config (llm_config , LLMConfig )
40+ self .blob_config = self ._validate_config (blob_config , BlobConfig )
41+ self .database_config = self ._validate_config (database_config , DatabaseConfig )
42+ self .memorize_config = self ._validate_config (memorize_config , MemorizeConfig )
43+ self .fs = LocalFS (self .blob_config .resources_dir )
3144 self .store = InMemoryStore ()
32- backend = ( settings . llm_client_backend or "httpx" ). lower ()
45+ backend = self . llm_config . client_backend
3346 self .openai : Any
3447 client_kwargs : dict [str , Any ] = {
35- "base_url" : settings . openai_base ,
36- "api_key" : settings . openai_api_key ,
37- "chat_model" : settings .chat_model ,
38- "embed_model" : settings .embed_model ,
48+ "base_url" : self . llm_config . base_url ,
49+ "api_key" : self . llm_config . api_key ,
50+ "chat_model" : self . llm_config .chat_model ,
51+ "embed_model" : self . llm_config .embed_model ,
3952 }
4053 if backend == "sdk" :
4154 from memu .llm .openai_sdk import OpenAISDKClient
4255
4356 self .openai = OpenAISDKClient (** client_kwargs )
4457 elif backend == "httpx" :
4558 self .openai = HTTPLLMClient (
46- provider = self .settings . llm_http_provider ,
47- endpoint_overrides = self .settings . llm_http_endpoints ,
59+ provider = self .llm_config . provider ,
60+ endpoint_overrides = self .llm_config . endpoint_overrides ,
4861 ** client_kwargs ,
4962 )
5063 else :
51- msg = f"Unknown llm_client_backend '{ settings . llm_client_backend } '"
64+ msg = f"Unknown llm_client_backend '{ self . llm_config . client_backend } '"
5265 raise ValueError (msg )
5366
54- self .category_configs : list [dict [str , str ]] = list (settings .memory_categories or [])
67+ self .category_configs : list [dict [str , str ]] = list (self . memorize_config .memory_categories or [])
5568 self ._category_prompt_str = self ._format_categories_for_prompt (self .category_configs )
5669 self ._category_ids : list [str ] = []
5770 self ._category_name_to_id : dict [str , str ] = {}
@@ -121,11 +134,12 @@ async def _create_resource_with_caption(
121134 return res
122135
def _resolve_memory_types(self) -> list[MemoryType]:
    """Return the memory types to process for this user.

    Reads ``memorize_config.memory_types``; when that is unset or empty,
    falls back to ``DEFAULT_MEMORY_TYPES``. Each entry is cast to the
    ``MemoryType`` literal type for the type checker.
    """
    selected = self.memorize_config.memory_types or DEFAULT_MEMORY_TYPES
    resolved: list[MemoryType] = []
    for type_name in selected:
        resolved.append(cast(MemoryType, type_name))
    return resolved
126139
127140 def _resolve_summary_prompt (self , modality : str , override : str | None ) -> str :
128- return override or self .settings .summary_prompts .get (modality ) or self .settings .default_summary_prompt
141+ memo_settings = self .memorize_config
142+ return override or memo_settings .summary_prompts .get (modality ) or memo_settings .default_summary_prompt
129143
130144 async def _generate_structured_entries (
131145 self ,
@@ -580,7 +594,7 @@ def _add_conversation_indices(self, conversation: str) -> str:
580594
581595 def _build_memory_type_prompt (self , * , memory_type : MemoryType , resource_text : str , categories_str : str ) -> str :
582596 template = (
583- self .settings .memory_type_prompts .get (memory_type ) or MEMORY_TYPE_PROMPTS .get (memory_type ) or ""
597+ self .memorize_config .memory_type_prompts .get (memory_type ) or MEMORY_TYPE_PROMPTS .get (memory_type ) or ""
584598 ).strip ()
585599 if not template :
586600 return resource_text
@@ -596,7 +610,7 @@ def _build_category_summary_prompt(self, *, category: MemoryCategory, new_memori
596610 category = self ._escape_prompt_value (category .name ),
597611 original_content = self ._escape_prompt_value (original or "" ),
598612 new_memory_items_text = self ._escape_prompt_value (new_items_text or "No new memory items." ),
599- target_length = self .settings .category_summary_target_length ,
613+ target_length = self .memorize_config .category_summary_target_length ,
600614 )
601615
602616 async def _update_category_summaries (self , updates : dict [str , list [str ]]) -> None :
@@ -752,6 +766,17 @@ def _model_dump_without_embeddings(self, obj: BaseModel) -> dict[str, Any]:
752766 data .pop ("embedding" , None )
753767 return data
754768
@staticmethod
def _validate_config(
    config: Mapping[str, Any] | BaseModel | None,
    model_type: type[TConfigModel],
) -> TConfigModel:
    """Normalize a user-supplied config into a *model_type* instance.

    Accepts three shapes: an already-built model instance (returned
    unchanged), ``None`` (replaced by a model built from its field
    defaults), or a mapping (validated via pydantic's
    ``model_validate``).
    """
    if config is None:
        # No config supplied: fall back to the model's declared defaults.
        return model_type()
    if isinstance(config, model_type):
        # Already the right model type; hand it back as-is.
        return config
    # Mapping (or other pydantic-coercible input): validate into the model.
    return model_type.model_validate(config)
755780 async def retrieve (self , query : str , * , top_k : int = 5 ) -> dict [str , Any ]:
756781 qvec = (await self .openai .embed ([query ]))[0 ]
757782 response : dict [str , list [dict [str , Any ]]] = {"resources" : [], "items" : [], "categories" : []}
0 commit comments