2121from memu .app import MemoryService
2222
2323
24- async def test_gemini_full_workflow ():
25- """Test Gemini integration with full MemU workflow."""
26- api_key = os .environ .get ("GEMINI_API_KEY" )
27- if not api_key :
28- print ("ERROR: GEMINI_API_KEY environment variable not set" )
29- print ("Please set it with: export GEMINI_API_KEY=your_api_key" )
30- sys .exit (1 )
31-
32- # Use minimal conversation for free-tier rate limit friendly testing
33- file_path = os .path .abspath (os .path .join (os .path .dirname (__file__ ), "example" , "example_conversation_minimal.json" ))
34- if not os .path .exists (file_path ):
35- print (f"ERROR: Test file not found: { file_path } " )
36- sys .exit (1 )
37-
38- output_data = {}
39-
40- print ("\n " + "=" * 60 )
41- print ("[GEMINI] Starting full workflow test..." )
42- print ("=" * 60 )
43-
44- service = MemoryService (
45- llm_profiles = {
46- "default" : {
47- "provider" : "gemini" ,
48- "client_backend" : "httpx" ,
49- "base_url" : "https://generativelanguage.googleapis.com/v1beta" ,
50- "api_key" : api_key ,
51- "chat_model" : "gemini-2.5-flash" ,
52- "embed_model" : "text-embedding-004" ,
53- },
54- },
55- database_config = {
56- "metadata_store" : {"provider" : "inmemory" },
57- },
58- retrieve_config = {
59- "method" : "rag" ,
60- "route_intention" : False ,
61- },
62- )
63-
64- # Test 1: Memorize conversation
65- print ("\n [GEMINI] Test 1: Memorizing conversation..." )
24+ async def _test_memorize (service , file_path , user_id , output_data ):
25+ """Test memorization of conversation."""
26+ print (f"\n [GEMINI] Test 1: Memorizing conversation..." )
6627 memory = await service .memorize (
6728 resource_url = file_path ,
6829 modality = "conversation" ,
69- user = {"user_id" : "gemini_test_user" }
30+ user = {"user_id" : user_id }
7031 )
7132 items_count = len (memory .get ("items" , []))
7233 categories_count = len (memory .get ("categories" , []))
@@ -81,72 +42,108 @@ async def test_gemini_full_workflow():
8142
8243 for cat in memory .get ("categories" , [])[:3 ]:
8344 print (f" - { cat .get ('name' )} : { (cat .get ('summary' ) or '' )[:60 ]} ..." )
45+ return memory
8446
85- queries = [
86- {"role" : "user" , "content" : {"text" : "What foods does the user like to eat?" }},
87- ]
8847
89- # Test 2: RAG-based retrieval
90- print ("\n [GEMINI] Test 2: RAG-based retrieval..." )
91- service .retrieve_config .method = "rag"
92- result_rag = await service .retrieve (queries = queries , where = {"user_id" : "gemini_test_user" })
48+ async def _test_retrieval (service , queries , user_id , method , output_data , test_num , test_name ):
49+ """Test retrieval with specified method."""
50+ print (f"\n [GEMINI] Test { test_num } : { test_name } ..." )
51+ service .retrieve_config .method = method
52+ result = await service .retrieve (queries = queries , where = {"user_id" : user_id })
9353
94- categories_retrieved = len (result_rag .get ("categories" , []))
95- items_retrieved = len (result_rag .get ("items" , []))
54+ categories_retrieved = len (result .get ("categories" , []))
55+ items_retrieved = len (result .get ("items" , []))
9656
9757 print (f" Retrieved { categories_retrieved } categories" )
9858 print (f" Retrieved { items_retrieved } items" )
9959
100- output_data ["retrieve_rag " ] = result_rag
60+ output_data [f"retrieve_ { method } " ] = result
10161
10262 if categories_retrieved > 0 :
10363 print (" Categories:" )
104- for cat in result_rag .get ("categories" , [])[:3 ]:
64+ for cat in result .get ("categories" , [])[:3 ]:
10565 print (f" - { cat .get ('name' )} : { (cat .get ('summary' ) or cat .get ('description' , '' ))[:60 ]} ..." )
10666
10767 if items_retrieved > 0 :
10868 print (" Items:" )
109- for item in result_rag .get ("items" , [])[:3 ]:
69+ for item in result .get ("items" , [])[:3 ]:
11070 print (f" - [{ item .get ('memory_type' )} ] { item .get ('summary' , '' )[:80 ]} ..." )
71+ return result
11172
112- # Test 3: LLM-based retrieval
113- print ("\n [GEMINI] Test 3: LLM-based retrieval..." )
114- service .retrieve_config .method = "llm"
115- result_llm = await service .retrieve (queries = queries , where = {"user_id" : "gemini_test_user" })
116-
117- categories_retrieved = len (result_llm .get ("categories" , []))
118- items_retrieved = len (result_llm .get ("items" , []))
119-
120- print (f" Retrieved { categories_retrieved } categories" )
121- print (f" Retrieved { items_retrieved } items" )
122-
123- output_data ["retrieve_llm" ] = result_llm
124-
125- if categories_retrieved > 0 :
126- print (" Categories:" )
127- for cat in result_llm .get ("categories" , [])[:3 ]:
128- print (f" - { cat .get ('name' )} : { (cat .get ('summary' ) or cat .get ('description' , '' ))[:60 ]} ..." )
129-
130- if items_retrieved > 0 :
131- print (" Items:" )
132- for item in result_llm .get ("items" , [])[:3 ]:
133- print (f" - [{ item .get ('memory_type' )} ] { item .get ('summary' , '' )[:80 ]} ..." )
13473
135- # Test 4: List memory items
74+ async def _test_list_items (service , user_id , output_data ):
75+ """Test listing memory items."""
13676 print ("\n [GEMINI] Test 4: List memory items..." )
137- items_result = await service .list_memory_items (where = {"user_id" : "gemini_test_user" })
77+ items_result = await service .list_memory_items (where = {"user_id" : user_id })
13878 items_list = items_result .get ("items" , [])
13979 print (f" Listed { len (items_list )} memory items" )
14080 output_data ["list_items" ] = items_result
14181 assert len (items_list ) > 0 , "Expected at least 1 item in list"
82+ return items_result
83+
14284
143- # Test 5: List memory categories
85+ async def _test_list_categories (service , user_id , output_data ):
86+ """Test listing memory categories."""
14487 print ("\n [GEMINI] Test 5: List memory categories..." )
145- cats_result = await service .list_memory_categories (where = {"user_id" : "gemini_test_user" })
88+ cats_result = await service .list_memory_categories (where = {"user_id" : user_id })
14689 cats_list = cats_result .get ("categories" , [])
14790 print (f" Listed { len (cats_list )} categories" )
14891 output_data ["list_categories" ] = cats_result
14992 assert len (cats_list ) > 0 , "Expected at least 1 category in list"
93+ return cats_result
94+
95+
96+ async def test_gemini_full_workflow ():
97+ """Test Gemini integration with full MemU workflow."""
98+ api_key = os .environ .get ("GEMINI_API_KEY" )
99+ if not api_key :
100+ print ("ERROR: GEMINI_API_KEY environment variable not set" )
101+ print ("Please set it with: export GEMINI_API_KEY=your_api_key" )
102+ sys .exit (1 )
103+
104+ # Use minimal conversation for free-tier rate limit friendly testing
105+ file_path = os .path .abspath (os .path .join (os .path .dirname (__file__ ), ".." , "examples" , "resources" , "conversations" , "conv1.json" ))
106+ if not os .path .exists (file_path ):
107+ print (f"ERROR: Test file not found: { file_path } " )
108+ sys .exit (1 )
109+
110+ output_data = {}
111+
112+ print ("\n " + "=" * 60 )
113+ print ("[GEMINI] Starting full workflow test..." )
114+ print ("=" * 60 )
115+
116+ service = MemoryService (
117+ llm_profiles = {
118+ "default" : {
119+ "provider" : "gemini" ,
120+ "client_backend" : "httpx" ,
121+ "base_url" : "https://generativelanguage.googleapis.com/v1beta" ,
122+ "api_key" : api_key ,
123+ "chat_model" : "gemini-2.5-pro" ,
124+ "embed_model" : "text-embedding-004" ,
125+ },
126+ },
127+ database_config = {
128+ "metadata_store" : {"provider" : "inmemory" },
129+ },
130+ retrieve_config = {
131+ "method" : "rag" ,
132+ "route_intention" : False ,
133+ },
134+ )
135+
136+ user_id = "gemini_test_user"
137+ await _test_memorize (service , file_path , user_id , output_data )
138+
139+ queries = [
140+ {"role" : "user" , "content" : {"text" : "What foods does the user like to eat?" }},
141+ ]
142+
143+ await _test_retrieval (service , queries , user_id , "rag" , output_data , 2 , "RAG-based retrieval" )
144+ await _test_retrieval (service , queries , user_id , "llm" , output_data , 3 , "LLM-based retrieval" )
145+ await _test_list_items (service , user_id , output_data )
146+ await _test_list_categories (service , user_id , output_data )
150147
151148 # Save output to file
152149 output_file = os .path .abspath (os .path .join (os .path .dirname (__file__ ), ".." , "examples" , "output" , "gemini_test_output.json" ))
@@ -162,3 +159,4 @@ async def test_gemini_full_workflow():
162159
163160if __name__ == "__main__" :
164161 asyncio .run (test_gemini_full_workflow ())
162+
0 commit comments