Skip to content

Commit f13cb14

Browse files
committed
updated tests added
1 parent e3cd44d commit f13cb14

1 file changed

Lines changed: 79 additions & 81 deletions

File tree

tests/test_gemini.py

Lines changed: 79 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -21,52 +21,13 @@
2121
from memu.app import MemoryService
2222

2323

24-
async def test_gemini_full_workflow():
25-
"""Test Gemini integration with full MemU workflow."""
26-
api_key = os.environ.get("GEMINI_API_KEY")
27-
if not api_key:
28-
print("ERROR: GEMINI_API_KEY environment variable not set")
29-
print("Please set it with: export GEMINI_API_KEY=your_api_key")
30-
sys.exit(1)
31-
32-
# Use minimal conversation for free-tier rate limit friendly testing
33-
file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "example", "example_conversation_minimal.json"))
34-
if not os.path.exists(file_path):
35-
print(f"ERROR: Test file not found: {file_path}")
36-
sys.exit(1)
37-
38-
output_data = {}
39-
40-
print("\n" + "=" * 60)
41-
print("[GEMINI] Starting full workflow test...")
42-
print("=" * 60)
43-
44-
service = MemoryService(
45-
llm_profiles={
46-
"default": {
47-
"provider": "gemini",
48-
"client_backend": "httpx",
49-
"base_url": "https://generativelanguage.googleapis.com/v1beta",
50-
"api_key": api_key,
51-
"chat_model": "gemini-2.5-flash",
52-
"embed_model": "text-embedding-004",
53-
},
54-
},
55-
database_config={
56-
"metadata_store": {"provider": "inmemory"},
57-
},
58-
retrieve_config={
59-
"method": "rag",
60-
"route_intention": False,
61-
},
62-
)
63-
64-
# Test 1: Memorize conversation
65-
print("\n[GEMINI] Test 1: Memorizing conversation...")
24+
async def _test_memorize(service, file_path, user_id, output_data):
25+
"""Test memorization of conversation."""
26+
print(f"\n[GEMINI] Test 1: Memorizing conversation...")
6627
memory = await service.memorize(
6728
resource_url=file_path,
6829
modality="conversation",
69-
user={"user_id": "gemini_test_user"}
30+
user={"user_id": user_id}
7031
)
7132
items_count = len(memory.get("items", []))
7233
categories_count = len(memory.get("categories", []))
@@ -81,72 +42,108 @@ async def test_gemini_full_workflow():
8142

8243
for cat in memory.get("categories", [])[:3]:
8344
print(f" - {cat.get('name')}: {(cat.get('summary') or '')[:60]}...")
45+
return memory
8446

85-
queries = [
86-
{"role": "user", "content": {"text": "What foods does the user like to eat?"}},
87-
]
8847

89-
# Test 2: RAG-based retrieval
90-
print("\n[GEMINI] Test 2: RAG-based retrieval...")
91-
service.retrieve_config.method = "rag"
92-
result_rag = await service.retrieve(queries=queries, where={"user_id": "gemini_test_user"})
48+
async def _test_retrieval(service, queries, user_id, method, output_data, test_num, test_name):
49+
"""Test retrieval with specified method."""
50+
print(f"\n[GEMINI] Test {test_num}: {test_name}...")
51+
service.retrieve_config.method = method
52+
result = await service.retrieve(queries=queries, where={"user_id": user_id})
9353

94-
categories_retrieved = len(result_rag.get("categories", []))
95-
items_retrieved = len(result_rag.get("items", []))
54+
categories_retrieved = len(result.get("categories", []))
55+
items_retrieved = len(result.get("items", []))
9656

9757
print(f" Retrieved {categories_retrieved} categories")
9858
print(f" Retrieved {items_retrieved} items")
9959

100-
output_data["retrieve_rag"] = result_rag
60+
output_data[f"retrieve_{method}"] = result
10161

10262
if categories_retrieved > 0:
10363
print(" Categories:")
104-
for cat in result_rag.get("categories", [])[:3]:
64+
for cat in result.get("categories", [])[:3]:
10565
print(f" - {cat.get('name')}: {(cat.get('summary') or cat.get('description', ''))[:60]}...")
10666

10767
if items_retrieved > 0:
10868
print(" Items:")
109-
for item in result_rag.get("items", [])[:3]:
69+
for item in result.get("items", [])[:3]:
11070
print(f" - [{item.get('memory_type')}] {item.get('summary', '')[:80]}...")
71+
return result
11172

112-
# Test 3: LLM-based retrieval
113-
print("\n[GEMINI] Test 3: LLM-based retrieval...")
114-
service.retrieve_config.method = "llm"
115-
result_llm = await service.retrieve(queries=queries, where={"user_id": "gemini_test_user"})
116-
117-
categories_retrieved = len(result_llm.get("categories", []))
118-
items_retrieved = len(result_llm.get("items", []))
119-
120-
print(f" Retrieved {categories_retrieved} categories")
121-
print(f" Retrieved {items_retrieved} items")
122-
123-
output_data["retrieve_llm"] = result_llm
124-
125-
if categories_retrieved > 0:
126-
print(" Categories:")
127-
for cat in result_llm.get("categories", [])[:3]:
128-
print(f" - {cat.get('name')}: {(cat.get('summary') or cat.get('description', ''))[:60]}...")
129-
130-
if items_retrieved > 0:
131-
print(" Items:")
132-
for item in result_llm.get("items", [])[:3]:
133-
print(f" - [{item.get('memory_type')}] {item.get('summary', '')[:80]}...")
13473

135-
# Test 4: List memory items
74+
async def _test_list_items(service, user_id, output_data):
75+
"""Test listing memory items."""
13676
print("\n[GEMINI] Test 4: List memory items...")
137-
items_result = await service.list_memory_items(where={"user_id": "gemini_test_user"})
77+
items_result = await service.list_memory_items(where={"user_id": user_id})
13878
items_list = items_result.get("items", [])
13979
print(f" Listed {len(items_list)} memory items")
14080
output_data["list_items"] = items_result
14181
assert len(items_list) > 0, "Expected at least 1 item in list"
82+
return items_result
83+
14284

143-
# Test 5: List memory categories
85+
async def _test_list_categories(service, user_id, output_data):
86+
"""Test listing memory categories."""
14487
print("\n[GEMINI] Test 5: List memory categories...")
145-
cats_result = await service.list_memory_categories(where={"user_id": "gemini_test_user"})
88+
cats_result = await service.list_memory_categories(where={"user_id": user_id})
14689
cats_list = cats_result.get("categories", [])
14790
print(f" Listed {len(cats_list)} categories")
14891
output_data["list_categories"] = cats_result
14992
assert len(cats_list) > 0, "Expected at least 1 category in list"
93+
return cats_result
94+
95+
96+
async def test_gemini_full_workflow():
97+
"""Test Gemini integration with full MemU workflow."""
98+
api_key = os.environ.get("GEMINI_API_KEY")
99+
if not api_key:
100+
print("ERROR: GEMINI_API_KEY environment variable not set")
101+
print("Please set it with: export GEMINI_API_KEY=your_api_key")
102+
sys.exit(1)
103+
104+
# Use minimal conversation for free-tier rate limit friendly testing
105+
file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "examples", "resources", "conversations", "conv1.json"))
106+
if not os.path.exists(file_path):
107+
print(f"ERROR: Test file not found: {file_path}")
108+
sys.exit(1)
109+
110+
output_data = {}
111+
112+
print("\n" + "=" * 60)
113+
print("[GEMINI] Starting full workflow test...")
114+
print("=" * 60)
115+
116+
service = MemoryService(
117+
llm_profiles={
118+
"default": {
119+
"provider": "gemini",
120+
"client_backend": "httpx",
121+
"base_url": "https://generativelanguage.googleapis.com/v1beta",
122+
"api_key": api_key,
123+
"chat_model": "gemini-2.5-pro",
124+
"embed_model": "text-embedding-004",
125+
},
126+
},
127+
database_config={
128+
"metadata_store": {"provider": "inmemory"},
129+
},
130+
retrieve_config={
131+
"method": "rag",
132+
"route_intention": False,
133+
},
134+
)
135+
136+
user_id = "gemini_test_user"
137+
await _test_memorize(service, file_path, user_id, output_data)
138+
139+
queries = [
140+
{"role": "user", "content": {"text": "What foods does the user like to eat?"}},
141+
]
142+
143+
await _test_retrieval(service, queries, user_id, "rag", output_data, 2, "RAG-based retrieval")
144+
await _test_retrieval(service, queries, user_id, "llm", output_data, 3, "LLM-based retrieval")
145+
await _test_list_items(service, user_id, output_data)
146+
await _test_list_categories(service, user_id, output_data)
150147

151148
# Save output to file
152149
output_file = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "examples", "output", "gemini_test_output.json"))
@@ -162,3 +159,4 @@ async def test_gemini_full_workflow():
162159

163160
if __name__ == "__main__":
164161
asyncio.run(test_gemini_full_workflow())
162+

0 commit comments

Comments
 (0)