Skip to content

Commit b2c4fb8

Browse files
authored (author name not captured in this page extract)
feat: expose configs (#80)
1 parent b154e02 commit b2c4fb8

File tree

3 files changed

+69
-38
lines changed

3 files changed

+69
-38
lines changed

src/memu/app/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from memu.app.service import MemoryUser
2-
from memu.app.settings import AppSettings
2+
from memu.app.settings import BlobConfig, DatabaseConfig, LLMConfig, MemorizeConfig
33

4-
__all__ = ["AppSettings", "MemoryUser"]
4+
__all__ = ["BlobConfig", "DatabaseConfig", "LLMConfig", "MemorizeConfig", "MemoryUser"]

src/memu/app/service.py

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
import logging
44
import pathlib
55
import re
6-
from collections.abc import Sequence
7-
from typing import Any, cast
6+
from collections.abc import Mapping, Sequence
7+
from typing import Any, TypeVar, cast
88

99
from pydantic import BaseModel
1010

11-
from memu.app.settings import AppSettings
11+
from memu.app.settings import BlobConfig, DatabaseConfig, LLMConfig, MemorizeConfig
1212
from memu.llm.http_client import HTTPLLMClient
1313
from memu.memory.repo import InMemoryStore
1414
from memu.models import CategoryItem, MemoryCategory, MemoryItem, MemoryType, Resource
@@ -24,34 +24,47 @@
2424
logger = logging.getLogger(__name__)
2525

2626

27+
TConfigModel = TypeVar("TConfigModel", bound=BaseModel)
28+
29+
2730
class MemoryUser:
28-
def __init__(self, settings: AppSettings):
29-
self.settings = settings
30-
self.fs = LocalFS(settings.resources_dir)
31+
def __init__(
32+
self,
33+
*,
34+
llm_config: dict[str, Any] | LLMConfig | None = None,
35+
blob_config: dict[str, Any] | BlobConfig | None = None,
36+
database_config: dict[str, Any] | DatabaseConfig | None = None,
37+
memorize_config: dict[str, Any] | MemorizeConfig | None = None,
38+
):
39+
self.llm_config = self._validate_config(llm_config, LLMConfig)
40+
self.blob_config = self._validate_config(blob_config, BlobConfig)
41+
self.database_config = self._validate_config(database_config, DatabaseConfig)
42+
self.memorize_config = self._validate_config(memorize_config, MemorizeConfig)
43+
self.fs = LocalFS(self.blob_config.resources_dir)
3144
self.store = InMemoryStore()
32-
backend = (settings.llm_client_backend or "httpx").lower()
45+
backend = self.llm_config.client_backend
3346
self.openai: Any
3447
client_kwargs: dict[str, Any] = {
35-
"base_url": settings.openai_base,
36-
"api_key": settings.openai_api_key,
37-
"chat_model": settings.chat_model,
38-
"embed_model": settings.embed_model,
48+
"base_url": self.llm_config.base_url,
49+
"api_key": self.llm_config.api_key,
50+
"chat_model": self.llm_config.chat_model,
51+
"embed_model": self.llm_config.embed_model,
3952
}
4053
if backend == "sdk":
4154
from memu.llm.openai_sdk import OpenAISDKClient
4255

4356
self.openai = OpenAISDKClient(**client_kwargs)
4457
elif backend == "httpx":
4558
self.openai = HTTPLLMClient(
46-
provider=self.settings.llm_http_provider,
47-
endpoint_overrides=self.settings.llm_http_endpoints,
59+
provider=self.llm_config.provider,
60+
endpoint_overrides=self.llm_config.endpoint_overrides,
4861
**client_kwargs,
4962
)
5063
else:
51-
msg = f"Unknown llm_client_backend '{settings.llm_client_backend}'"
64+
msg = f"Unknown llm_client_backend '{self.llm_config.client_backend}'"
5265
raise ValueError(msg)
5366

54-
self.category_configs: list[dict[str, str]] = list(settings.memory_categories or [])
67+
self.category_configs: list[dict[str, str]] = list(self.memorize_config.memory_categories or [])
5568
self._category_prompt_str = self._format_categories_for_prompt(self.category_configs)
5669
self._category_ids: list[str] = []
5770
self._category_name_to_id: dict[str, str] = {}
@@ -121,11 +134,12 @@ async def _create_resource_with_caption(
121134
return res
122135

123136
def _resolve_memory_types(self) -> list[MemoryType]:
124-
configured_types = self.settings.memory_types or DEFAULT_MEMORY_TYPES
137+
configured_types = self.memorize_config.memory_types or DEFAULT_MEMORY_TYPES
125138
return [cast(MemoryType, mtype) for mtype in configured_types]
126139

127140
def _resolve_summary_prompt(self, modality: str, override: str | None) -> str:
128-
return override or self.settings.summary_prompts.get(modality) or self.settings.default_summary_prompt
141+
memo_settings = self.memorize_config
142+
return override or memo_settings.summary_prompts.get(modality) or memo_settings.default_summary_prompt
129143

130144
async def _generate_structured_entries(
131145
self,
@@ -580,7 +594,7 @@ def _add_conversation_indices(self, conversation: str) -> str:
580594

581595
def _build_memory_type_prompt(self, *, memory_type: MemoryType, resource_text: str, categories_str: str) -> str:
582596
template = (
583-
self.settings.memory_type_prompts.get(memory_type) or MEMORY_TYPE_PROMPTS.get(memory_type) or ""
597+
self.memorize_config.memory_type_prompts.get(memory_type) or MEMORY_TYPE_PROMPTS.get(memory_type) or ""
584598
).strip()
585599
if not template:
586600
return resource_text
@@ -596,7 +610,7 @@ def _build_category_summary_prompt(self, *, category: MemoryCategory, new_memori
596610
category=self._escape_prompt_value(category.name),
597611
original_content=self._escape_prompt_value(original or ""),
598612
new_memory_items_text=self._escape_prompt_value(new_items_text or "No new memory items."),
599-
target_length=self.settings.category_summary_target_length,
613+
target_length=self.memorize_config.category_summary_target_length,
600614
)
601615

602616
async def _update_category_summaries(self, updates: dict[str, list[str]]) -> None:
@@ -752,6 +766,17 @@ def _model_dump_without_embeddings(self, obj: BaseModel) -> dict[str, Any]:
752766
data.pop("embedding", None)
753767
return data
754768

769+
@staticmethod
770+
def _validate_config(
771+
config: Mapping[str, Any] | BaseModel | None,
772+
model_type: type[TConfigModel],
773+
) -> TConfigModel:
774+
if isinstance(config, model_type):
775+
return config
776+
if config is None:
777+
return model_type()
778+
return model_type.model_validate(config)
779+
755780
async def retrieve(self, query: str, *, top_k: int = 5) -> dict[str, Any]:
756781
qvec = (await self.openai.embed([query]))[0]
757782
response: dict[str, list[dict[str, Any]]] = {"resources": [], "items": [], "categories": []}

src/memu/app/settings.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,30 +27,36 @@ def _default_memory_categories() -> list[dict[str, str]]:
2727
]
2828

2929

30-
class AppSettings(BaseModel):
31-
# where to store raw resources
32-
resources_dir: str = Field(default="./resources")
33-
# openai base
34-
openai_base: str = Field(default="https://api.openai.com/v1")
35-
openai_api_key: str = Field(default="OPENAI_API_KEY")
36-
# models
37-
chat_model: str = Field(default="gpt-5-nano")
30+
class LLMConfig(BaseModel):
31+
provider: str = Field(
32+
default="openai",
33+
description="Identifier for the LLM provider implementation (used by HTTP client backend).",
34+
)
35+
base_url: str = Field(default="https://api.openai.com/v1")
36+
api_key: str = Field(default="OPENAI_API_KEY")
37+
chat_model: str = Field(default="gpt-4o-mini")
3838
embed_model: str = Field(default="text-embedding-3-small")
39-
llm_client_backend: str = Field(
39+
client_backend: str = Field(
4040
default="sdk",
41-
description="Which OpenAI client backend to use: 'httpx' (httpx) or 'sdk' (official OpenAI).",
41+
description="Which LLM client backend to use: 'httpx' (httpx) or 'sdk' (official OpenAI).",
4242
)
43-
llm_http_provider: str = Field(
44-
default="openai",
45-
description="Name of the HTTP LLM provider implementation (e.g. 'openai').",
46-
)
47-
llm_http_endpoints: dict[str, str] = Field(
43+
endpoint_overrides: dict[str, str] = Field(
4844
default_factory=dict,
4945
description="Optional overrides for HTTP endpoints (keys: 'chat'/'summary', 'embeddings'/'embed').",
5046
)
51-
# thresholds
47+
48+
49+
class BlobConfig(BaseModel):
50+
provider: str = Field(default="local")
51+
resources_dir: str = Field(default="./data/resources")
52+
53+
54+
class DatabaseConfig(BaseModel):
55+
provider: str = Field(default="memory")
56+
57+
58+
class MemorizeConfig(BaseModel):
5259
category_assign_threshold: float = Field(default=0.25)
53-
# summarization prompts
5460
default_summary_prompt: str = Field(default="Summarize the text in one short paragraph.")
5561
summary_prompts: dict[str, str] = Field(
5662
default_factory=dict,

0 commit comments

Comments (0)