|
20 | 20 | from agent_memory_server.config import settings |
21 | 21 | from agent_memory_server.extraction import extract_discrete_memories |
22 | 22 | from agent_memory_server.llms import get_model_client |
23 | | -from agent_memory_server.models import MemoryRecord, MemoryTypeEnum |
24 | 23 |
|
25 | 24 |
|
26 | 25 | class GroundingEvaluationResult(BaseModel): |
@@ -244,23 +243,39 @@ async def evaluate_grounding( |
244 | 243 | class TestContextualGroundingIntegration: |
245 | 244 | """Integration tests for contextual grounding with real LLM calls""" |
246 | 245 |
|
247 | | - async def create_test_memory_with_context( |
248 | | - self, context_messages: list[str], target_message: str, context_date: datetime |
249 | | - ) -> MemoryRecord: |
250 | | - """Create a memory record with conversational context""" |
251 | | - # Combine context messages and target message |
252 | | - full_conversation = "\n".join(context_messages + [target_message]) |
253 | | - |
254 | | - return MemoryRecord( |
255 | | - id=str(ulid.ULID()), |
256 | | - text=full_conversation, |
257 | | - memory_type=MemoryTypeEnum.MESSAGE, |
258 | | - discrete_memory_extracted="f", |
259 | | - session_id=f"test-integration-session-{ulid.ULID()}", |
| 246 | + async def create_test_conversation_with_context( |
| 247 | + self, all_messages: list[str], context_date: datetime, session_id: str |
| 248 | + ) -> str: |
| 249 | + """Create a test conversation with proper working memory setup for cross-message grounding""" |
| 250 | + from agent_memory_server.models import MemoryMessage, WorkingMemory |
| 251 | + from agent_memory_server.working_memory import set_working_memory |
| 252 | + |
| 253 | + # Create individual MemoryMessage objects for each message in the conversation |
| 254 | + messages = [] |
| 255 | + for i, message_text in enumerate(all_messages): |
| 256 | + messages.append( |
| 257 | + MemoryMessage( |
| 258 | + id=str(ulid.ULID()), |
| 259 | + role="user" if i % 2 == 0 else "assistant", |
| 260 | + content=message_text, |
| 261 | + timestamp=context_date.isoformat(), |
| 262 | + discrete_memory_extracted="f", |
| 263 | + ) |
| 264 | + ) |
| 265 | + |
| 266 | + # Create working memory with the conversation |
| 267 | + working_memory = WorkingMemory( |
| 268 | + session_id=session_id, |
260 | 269 | user_id="test-integration-user", |
261 | | - timestamp=context_date.isoformat(), |
| 270 | + namespace="test-namespace", |
| 271 | + messages=messages, |
| 272 | + memories=[], |
262 | 273 | ) |
263 | 274 |
|
| 275 | + # Store in working memory for thread-aware extraction |
| 276 | + await set_working_memory(working_memory) |
| 277 | + return session_id |
| 278 | + |
264 | 279 | async def test_pronoun_grounding_integration_he_him(self): |
265 | 280 | """Integration test for he/him pronoun grounding with real LLM""" |
266 | 281 | example = ContextualGroundingBenchmark.get_pronoun_grounding_examples()[0] |
@@ -407,35 +422,31 @@ async def test_comprehensive_grounding_evaluation_with_judge(self): |
407 | 422 | ] # Just first 2 for integration testing |
408 | 423 |
|
409 | 424 | for example in sample_examples: |
410 | | - # Create memory and extract with real LLM |
411 | | - memory = await self.create_test_memory_with_context( |
412 | | - example["messages"][:-1], |
413 | | - example["messages"][-1], |
414 | | - example["context_date"], |
| 425 | + # Create a unique session for this test |
| 426 | + session_id = f"test-grounding-{ulid.ULID()}" |
| 427 | + |
| 428 | + # Set up proper conversation context for cross-message grounding |
| 429 | + await self.create_test_conversation_with_context( |
| 430 | + example["messages"], example["context_date"], session_id |
415 | 431 | ) |
416 | 432 |
|
417 | 433 | original_text = example["messages"][-1] |
418 | 434 |
|
419 | | - # Store and extract |
420 | | - from agent_memory_server.vectorstore_factory import get_vectorstore_adapter |
421 | | - |
422 | | - adapter = await get_vectorstore_adapter() |
423 | | - await adapter.add_memories([memory]) |
424 | | - await extract_discrete_memories([memory]) |
| 435 | + # Use thread-aware extraction (the whole point of our implementation!) |
| 436 | + from agent_memory_server.long_term_memory import ( |
| 437 | + extract_memories_from_session_thread, |
| 438 | + ) |
425 | 439 |
|
426 | | - # Retrieve all extracted discrete memories to get the grounded text |
427 | | - all_memories = await adapter.search_memories(query="", limit=50) |
428 | | - discrete_memories = [ |
429 | | - m |
430 | | - for m in all_memories.memories |
431 | | - if m.memory_type in ["episodic", "semantic"] |
432 | | - and m.session_id == memory.session_id |
433 | | - ] |
| 440 | + extracted_memories = await extract_memories_from_session_thread( |
| 441 | + session_id=session_id, |
| 442 | + namespace="test-namespace", |
| 443 | + user_id="test-integration-user", |
| 444 | + ) |
434 | 445 |
|
435 | 446 | # Combine the grounded memories into a single text for evaluation |
436 | 447 | grounded_text = ( |
437 | | - " ".join([dm.text for dm in discrete_memories]) |
438 | | - if discrete_memories |
| 448 | + " ".join([mem.text for mem in extracted_memories]) |
| 449 | + if extracted_memories |
439 | 450 | else original_text |
440 | 451 | ) |
441 | 452 |
|
|
0 commit comments