Skip to content

Commit af580a1

Browse files
authored
refactor: use literal type for retrieve method config (#97)
1 parent 1e732f0 commit af580a1

File tree

4 files changed

+36
-31
lines changed

4 files changed

+36
-31
lines changed

src/memu/app/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from memu.app.service import MemoryUser
1+
from memu.app.service import MemoryService
22
from memu.app.settings import BlobConfig, DatabaseConfig, LLMConfig, MemorizeConfig, RetrieveConfig
33

4-
__all__ = ["BlobConfig", "DatabaseConfig", "LLMConfig", "MemorizeConfig", "MemoryUser", "RetrieveConfig"]
4+
__all__ = ["BlobConfig", "DatabaseConfig", "LLMConfig", "MemorizeConfig", "MemoryService", "RetrieveConfig"]

src/memu/app/service.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,21 @@
3434
TConfigModel = TypeVar("TConfigModel", bound=BaseModel)
3535

3636

37-
class MemoryUser:
37+
class MemoryService:
3838
def __init__(
3939
self,
4040
*,
41-
llm_config: dict[str, Any] | LLMConfig | None = None,
42-
blob_config: dict[str, Any] | BlobConfig | None = None,
43-
database_config: dict[str, Any] | DatabaseConfig | None = None,
44-
memorize_config: dict[str, Any] | MemorizeConfig | None = None,
41+
llm_config: LLMConfig | dict[str, Any] | None = None,
42+
blob_config: BlobConfig | dict[str, Any] | None = None,
43+
database_config: DatabaseConfig | dict[str, Any] | None = None,
44+
memorize_config: MemorizeConfig | dict[str, Any] | None = None,
45+
retrieve_config: RetrieveConfig | dict[str, Any] | None = None,
4546
):
4647
self.llm_config = self._validate_config(llm_config, LLMConfig)
4748
self.blob_config = self._validate_config(blob_config, BlobConfig)
4849
self.database_config = self._validate_config(database_config, DatabaseConfig)
4950
self.memorize_config = self._validate_config(memorize_config, MemorizeConfig)
51+
self.retrieve_config = self._validate_config(retrieve_config, RetrieveConfig)
5052
self.fs = LocalFS(self.blob_config.resources_dir)
5153
self.store = InMemoryStore()
5254
backend = self.llm_config.client_backend
@@ -788,19 +790,13 @@ async def retrieve(
788790
self,
789791
query: str,
790792
*,
791-
retrieve_config: dict[str, Any] | RetrieveConfig | None = None,
792793
conversation_history: list[dict[str, str]] | None = None,
793794
) -> dict[str, Any]:
794795
"""
795796
Retrieve relevant memories based on the query using either RAG-based or LLM-based search.
796797
797798
Args:
798799
query: The search query string
799-
retrieve_config: Configuration for retrieval method and parameters.
800-
Can be a dict or RetrieveConfig object with:
801-
- method: 'rag' for embedding-based vector search (default),
802-
'llm' for LLM-based semantic ranking
803-
- top_k: Maximum number of results per category (default: 5)
804800
conversation_history: Optional list of last 3 conversation turns, each with 'role' and 'content'.
805801
Example: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
806802
@@ -819,14 +815,6 @@ async def retrieve(
819815
- Pre-retrieval decision checks if retrieval is needed based on query type
820816
- Query rewriting incorporates conversation history for better context
821817
"""
822-
# Validate and resolve config
823-
config = self._validate_config(retrieve_config, RetrieveConfig)
824-
825-
# Validate method
826-
if config.method not in ("rag", "llm"):
827-
msg = f"Invalid retrieval method '{config.method}'. Must be 'rag' or 'llm'."
828-
raise ValueError(msg)
829-
830818
# Step 1: Decide if retrieval is needed
831819
needs_retrieval, rewritten_query = await self._decide_if_retrieval_needed(query, conversation_history)
832820

@@ -844,13 +832,13 @@ async def retrieve(
844832
logger.info(f"Query rewritten: '{query}' -> '{rewritten_query}'")
845833

846834
# Step 2: Perform retrieval with rewritten query using configured method
847-
if config.method == "llm":
835+
if self.retrieve_config.method == "llm":
848836
results = await self._llm_based_retrieve(
849-
rewritten_query, top_k=config.top_k, conversation_history=conversation_history
837+
rewritten_query, top_k=self.retrieve_config.top_k, conversation_history=conversation_history
850838
)
851839
else: # rag
852840
results = await self._embedding_based_retrieve(
853-
rewritten_query, top_k=config.top_k, conversation_history=conversation_history
841+
rewritten_query, top_k=self.retrieve_config.top_k, conversation_history=conversation_history
854842
)
855843

856844
# Add metadata

src/memu/app/settings.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,20 @@
1-
from pydantic import BaseModel, Field
1+
from typing import Annotated, Literal
2+
3+
from pydantic import BaseModel, BeforeValidator, Field
24

35
from memu.prompts.memory_type import DEFAULT_MEMORY_TYPES
46
from memu.prompts.memory_type import PROMPTS as DEFAULT_MEMORY_TYPE_PROMPTS
57

68

9+
def normalize_value(v: str) -> str:
10+
if isinstance(v, str):
11+
return v.strip().lower()
12+
return v
13+
14+
15+
Normalize = BeforeValidator(normalize_value)
16+
17+
718
def _default_memory_types() -> list[str]:
819
return list(DEFAULT_MEMORY_TYPES)
920

@@ -56,10 +67,16 @@ class DatabaseConfig(BaseModel):
5667

5768

5869
class RetrieveConfig(BaseModel):
59-
method: str = Field(
60-
default="rag",
61-
description="Retrieval method: 'rag' for embedding-based vector search, 'llm' for LLM-based ranking.",
62-
)
70+
"""Configure retrieval behavior for `MemoryUser.retrieve`.
71+
72+
Attributes:
73+
method: Retrieval strategy. Use "rag" for embedding-based vector search or
74+
"llm" to delegate ranking to the LLM.
75+
top_k: Maximum number of results to return per category (and per stage),
76+
controlling breadth of the retrieved context.
77+
"""
78+
79+
method: Annotated[Literal["rag", "llm"], Normalize] = "rag"
6380
top_k: int = Field(
6481
default=5,
6582
description="Maximum number of results to return per category.",

uv.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)