Skip to content

Commit 69a9194

Browse files
committed
feat: support role content format
1 parent 1ce192f commit 69a9194

File tree

2 files changed

+104
-43
lines changed

2 files changed

+104
-43
lines changed

README.md

Lines changed: 42 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -101,7 +101,7 @@ pip install memu-py
101101
### Basic Usage
102102

103103
```python
104-
from memu.app import MemoryUser
104+
from memu.app import MemoryService
105105
import logging
106106

107107
async def test_memory_service():
@@ -112,47 +112,43 @@ async def test_memory_service():
112112
logger = logging.getLogger("memu")
113113
logger.setLevel(logging.DEBUG)
114114

115-
# Initialize MemoryUser with your OpenAI API key
116-
service = MemoryUser(llm_config={"api_key": "your-openai-api-key"})
115+
# Initialize MemoryService with your OpenAI API key
116+
service = MemoryService(llm_config={"api_key": "your-openai-api-key"})
117117

118118
# Memorize a conversation
119119
memory = await service.memorize(
120120
resource_url="tests/example/example_conversation.json",
121121
modality="conversation"
122122
)
123123

124-
# Example conversation history for query rewriting
125-
conversation_history = [
126-
{"role": "user", "content": "Tell me about the user's preferences"},
127-
{"role": "assistant", "content": "I'd be happy to help. Let me search the memory."},
128-
{"role": "user", "content": "What are their habits?"}
124+
# Test 1: RAG-based Retrieval with query context
125+
# Multiple queries enable automatic query rewriting with context
126+
print("\n[Test 1] RAG-based Retrieval with query context")
127+
queries_with_context = [
128+
{"role": "user", "content": {"text": "Tell me about the user's preferences"}},
129+
{"role": "assistant", "content": {"text": "I can help you with that. Let me search the memory."}},
130+
{"role": "user", "content": {"text": "What are their habits?"}},
129131
]
130-
131-
# Test 1: RAG-based Retrieval with conversation history
132-
print("\n[Test 1] RAG-based Retrieval with conversation history")
133-
retrieved_rag = await service.retrieve(
134-
query="What are their habits?",
135-
conversation_history=conversation_history,
136-
retrieve_config={"method": "rag", "top_k": 5}
137-
)
132+
retrieved_rag = await service.retrieve(queries=queries_with_context)
138133
print(f"Needs retrieval: {retrieved_rag.get('needs_retrieval')}")
139134
print(f"Original query: {retrieved_rag.get('original_query')}")
140135
print(f"Rewritten query: {retrieved_rag.get('rewritten_query')}")
136+
print(f"Next step query: {retrieved_rag.get('next_step_query')}")
141137
print(f"Results: {len(retrieved_rag.get('categories', []))} categories, "
142138
f"{len(retrieved_rag.get('items', []))} items")
143139

144-
# Test 2: LLM-based Retrieval with conversation history
145-
print("\n[Test 2] LLM-based Retrieval with conversation history")
146-
retrieved_llm = await service.retrieve(
147-
query="What are their habits?",
148-
conversation_history=conversation_history,
149-
retrieve_config={"method": "llm", "top_k": 5}
150-
)
151-
print(f"Needs retrieval: {retrieved_llm.get('needs_retrieval')}")
152-
print(f"Original query: {retrieved_llm.get('original_query')}")
153-
print(f"Rewritten query: {retrieved_llm.get('rewritten_query')}")
154-
print(f"Results: {len(retrieved_llm.get('categories', []))} categories, "
155-
f"{len(retrieved_llm.get('items', []))} items")
140+
# Test 2: Single query without context (no rewriting)
141+
print("\n[Test 2] Single query without context")
142+
queries_no_context = [
143+
{"role": "user", "content": {"text": "What are their habits?"}}
144+
]
145+
retrieved_single = await service.retrieve(queries=queries_no_context)
146+
print(f"Needs retrieval: {retrieved_single.get('needs_retrieval')}")
147+
print(f"Original query: {retrieved_single.get('original_query')}")
148+
print(f"Rewritten query: {retrieved_single.get('rewritten_query')}")
149+
print(f"Next step query: {retrieved_single.get('next_step_query')}")
150+
print(f"Results: {len(retrieved_single.get('categories', []))} categories, "
151+
f"{len(retrieved_single.get('items', []))} items")
156152

157153
if __name__ == "__main__":
158154
import asyncio
@@ -163,6 +159,22 @@ if __name__ == "__main__":
163159

164160
MemU provides two distinct retrieval approaches, each optimized for different scenarios:
165161

162+
#### **Query Structure**
163+
Queries are passed as a list of message objects in the format:
164+
```python
165+
[
166+
{"role": "user", "content": {"text": "Tell me about the user's preferences"}},
167+
{"role": "assistant", "content": {"text": "I can help you with that."}},
168+
{"role": "user", "content": {"text": "What are their habits?"}}
169+
]
170+
```
171+
172+
- **Roles** can be `user`, `assistant`, or other custom roles
173+
- The **last query** in the list is the current query
174+
- **Previous queries** (with their roles) provide context for automatic query rewriting
175+
- If only **one query** is provided, no rewriting occurs
176+
- The system returns a `next_step_query` to suggest the next retrieval step
177+
166178
#### **1. RAG-based Retrieval (`method="rag"`)**
167179
Fast embedding-based vector search using cosine similarity. Ideal for:
168180
- Large-scale datasets
@@ -190,9 +202,10 @@ This method uses the LLM to:
190202

191203
Both methods support:
192204
- **Full traceability**: Each retrieved item includes its `resource_id`, allowing you to trace back to the original source
193-
- **Conversation-aware rewriting**: Automatically resolves pronouns and references using conversation history
205+
- **Context-aware rewriting**: Automatically resolves pronouns and references using previous queries as context
194206
- **Pre-retrieval decision**: Intelligently determines if memory retrieval is needed for the query
195207
- **Progressive search**: Stops early if sufficient information is found at higher layers
208+
- **Next step suggestion**: Returns `next_step_query` for iterative multi-turn retrieval
196209

197210

198211

src/memu/app/service.py

Lines changed: 62 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -785,13 +785,14 @@ def _validate_config(
785785

786786
async def retrieve(
787787
self,
788-
queries: list[str],
788+
queries: list[dict[str, Any]],
789789
) -> dict[str, Any]:
790790
"""
791791
Retrieve relevant memories based on the query using either RAG-based or LLM-based search.
792792
793793
Args:
794-
queries: List of query strings. The last one is the current query, others are context.
794+
queries: List of query messages in format [{"role": "user", "content": {"text": "..."}}].
795+
The last one is the current query, others are context.
795796
If list has only 1 element, no query rewriting is performed.
796797
797798
Returns:
@@ -813,12 +814,13 @@ async def retrieve(
813814
if not queries:
814815
raise ValueError("empty_queries")
815816

816-
current_query = queries[-1]
817-
context_queries = queries[:-1] if len(queries) > 1 else []
817+
# Extract text from the query structure
818+
current_query = self._extract_query_text(queries[-1])
819+
context_queries_objs = queries[:-1] if len(queries) > 1 else []
818820

819821
# Step 1: Decide if retrieval is needed
820822
needs_retrieval, rewritten_query = await self._decide_if_retrieval_needed(
821-
current_query, context_queries, retrieved_content=None
823+
current_query, context_queries_objs, retrieved_content=None
822824
)
823825

824826
# If only one query, do not use the rewritten version (use original)
@@ -842,11 +844,11 @@ async def retrieve(
842844
# Step 2: Perform retrieval with rewritten query using configured method
843845
if self.retrieve_config.method == "llm":
844846
results = await self._llm_based_retrieve(
845-
rewritten_query, top_k=self.retrieve_config.top_k, context_queries=context_queries
847+
rewritten_query, top_k=self.retrieve_config.top_k, context_queries=context_queries_objs
846848
)
847849
else: # rag
848850
results = await self._embedding_based_retrieve(
849-
rewritten_query, top_k=self.retrieve_config.top_k, context_queries=context_queries
851+
rewritten_query, top_k=self.retrieve_config.top_k, context_queries=context_queries_objs
850852
)
851853

852854
# Add metadata
@@ -874,7 +876,7 @@ async def _rank_categories_by_summary(
874876
async def _decide_if_retrieval_needed(
875877
self,
876878
query: str,
877-
context_queries: list[str] | None,
879+
context_queries: list[dict[str, Any]] | None,
878880
retrieved_content: str | None = None,
879881
system_prompt: str | None = None,
880882
) -> tuple[bool, str]:
@@ -883,7 +885,7 @@ async def _decide_if_retrieval_needed(
883885
884886
Args:
885887
query: The current query string
886-
context_queries: List of context queries
888+
context_queries: List of previous query objects with role and content
887889
retrieved_content: Content retrieved so far (if checking for sufficiency)
888890
system_prompt: Optional system prompt override
889891
@@ -908,17 +910,61 @@ async def _decide_if_retrieval_needed(
908910

909911
return decision == "RETRIEVE", rewritten
910912

911-
def _format_query_context(self, queries: list[str] | None) -> str:
912-
"""Format query context for prompts"""
913+
def _format_query_context(self, queries: list[dict[str, Any]] | None) -> str:
914+
"""Format query context for prompts, including role information"""
913915
if not queries:
914916
return "No query context."
915917

916918
lines = []
917919
for q in queries:
918-
lines.append(f"- {q}")
920+
if isinstance(q, str):
921+
# Backward compatibility
922+
lines.append(f"- {q}")
923+
elif isinstance(q, dict):
924+
role = q.get("role", "user")
925+
content = q.get("content")
926+
if isinstance(content, dict):
927+
text = content.get("text", "")
928+
elif isinstance(content, str):
929+
text = content
930+
else:
931+
text = str(content)
932+
lines.append(f"- [{role}]: {text}")
933+
else:
934+
lines.append(f"- {q!s}")
919935

920936
return "\n".join(lines)
921937

938+
@staticmethod
939+
def _extract_query_text(query: dict[str, Any]) -> str:
940+
"""
941+
Extract text content from query message structure.
942+
943+
Args:
944+
query: Query in format {"role": "user", "content": {"text": "..."}}
945+
946+
Returns:
947+
The extracted text string
948+
"""
949+
if isinstance(query, str):
950+
# Backward compatibility: if it's already a string, return it
951+
return query
952+
953+
if not isinstance(query, dict):
954+
raise TypeError("INVALID")
955+
956+
content = query.get("content")
957+
if isinstance(content, dict):
958+
text = content.get("text", "")
959+
if not text:
960+
raise ValueError("EMPTY")
961+
return str(text)
962+
elif isinstance(content, str):
963+
# Also support {"role": "user", "content": "text"} format
964+
return content
965+
else:
966+
raise TypeError("INVALID")
967+
922968
def _extract_decision(self, raw: str) -> str:
923969
"""Extract RETRIEVE or NO_RETRIEVE decision from LLM response"""
924970
if not raw:
@@ -946,7 +992,7 @@ def _extract_rewritten_query(self, raw: str) -> str | None:
946992
return None
947993

948994
async def _embedding_based_retrieve(
949-
self, query: str, top_k: int, context_queries: list[str] | None
995+
self, query: str, top_k: int, context_queries: list[dict[str, Any]] | None
950996
) -> dict[str, Any]:
951997
"""Embedding-based retrieval with query rewriting and judging at each tier"""
952998
current_query = query
@@ -1056,7 +1102,9 @@ def _extract_judgement(self, raw: str) -> str:
10561102
return "ENOUGH"
10571103
return "MORE"
10581104

1059-
async def _llm_based_retrieve(self, query: str, top_k: int, context_queries: list[str] | None) -> dict[str, Any]:
1105+
async def _llm_based_retrieve(
1106+
self, query: str, top_k: int, context_queries: list[dict[str, Any]] | None
1107+
) -> dict[str, Any]:
10601108
"""
10611109
LLM-based retrieval that uses language model to search and rank results
10621110
in a hierarchical manner, with query rewriting and judging at each tier.

0 commit comments

Comments (0)