|
20 | 20 | from agent_memory_server.config import settings
|
21 | 21 | from agent_memory_server.extraction import extract_discrete_memories
|
22 | 22 | from agent_memory_server.llms import get_model_client
|
23 |
| -from agent_memory_server.models import MemoryRecord, MemoryTypeEnum |
24 | 23 |
|
25 | 24 |
|
26 | 25 | class GroundingEvaluationResult(BaseModel):
|
@@ -244,23 +243,39 @@ async def evaluate_grounding(
|
244 | 243 | class TestContextualGroundingIntegration:
|
245 | 244 | """Integration tests for contextual grounding with real LLM calls"""
|
246 | 245 |
|
247 |
| - async def create_test_memory_with_context( |
248 |
| - self, context_messages: list[str], target_message: str, context_date: datetime |
249 |
| - ) -> MemoryRecord: |
250 |
| - """Create a memory record with conversational context""" |
251 |
| - # Combine context messages and target message |
252 |
| - full_conversation = "\n".join(context_messages + [target_message]) |
253 |
| - |
254 |
| - return MemoryRecord( |
255 |
| - id=str(ulid.ULID()), |
256 |
| - text=full_conversation, |
257 |
| - memory_type=MemoryTypeEnum.MESSAGE, |
258 |
| - discrete_memory_extracted="f", |
259 |
| - session_id=f"test-integration-session-{ulid.ULID()}", |
| 246 | + async def create_test_conversation_with_context( |
| 247 | + self, all_messages: list[str], context_date: datetime, session_id: str |
| 248 | + ) -> str: |
| 249 | + """Create a test conversation with proper working memory setup for cross-message grounding""" |
| 250 | + from agent_memory_server.models import MemoryMessage, WorkingMemory |
| 251 | + from agent_memory_server.working_memory import set_working_memory |
| 252 | + |
| 253 | + # Create individual MemoryMessage objects for each message in the conversation |
| 254 | + messages = [] |
| 255 | + for i, message_text in enumerate(all_messages): |
| 256 | + messages.append( |
| 257 | + MemoryMessage( |
| 258 | + id=str(ulid.ULID()), |
| 259 | + role="user" if i % 2 == 0 else "assistant", |
| 260 | + content=message_text, |
| 261 | + timestamp=context_date.isoformat(), |
| 262 | + discrete_memory_extracted="f", |
| 263 | + ) |
| 264 | + ) |
| 265 | + |
| 266 | + # Create working memory with the conversation |
| 267 | + working_memory = WorkingMemory( |
| 268 | + session_id=session_id, |
260 | 269 | user_id="test-integration-user",
|
261 |
| - timestamp=context_date.isoformat(), |
| 270 | + namespace="test-namespace", |
| 271 | + messages=messages, |
| 272 | + memories=[], |
262 | 273 | )
|
263 | 274 |
|
| 275 | + # Store in working memory for thread-aware extraction |
| 276 | + await set_working_memory(working_memory) |
| 277 | + return session_id |
| 278 | + |
264 | 279 | async def test_pronoun_grounding_integration_he_him(self):
|
265 | 280 | """Integration test for he/him pronoun grounding with real LLM"""
|
266 | 281 | example = ContextualGroundingBenchmark.get_pronoun_grounding_examples()[0]
|
@@ -407,35 +422,31 @@ async def test_comprehensive_grounding_evaluation_with_judge(self):
|
407 | 422 | ] # Just first 2 for integration testing
|
408 | 423 |
|
409 | 424 | for example in sample_examples:
|
410 |
| - # Create memory and extract with real LLM |
411 |
| - memory = await self.create_test_memory_with_context( |
412 |
| - example["messages"][:-1], |
413 |
| - example["messages"][-1], |
414 |
| - example["context_date"], |
| 425 | + # Create a unique session for this test |
| 426 | + session_id = f"test-grounding-{ulid.ULID()}" |
| 427 | + |
| 428 | + # Set up proper conversation context for cross-message grounding |
| 429 | + await self.create_test_conversation_with_context( |
| 430 | + example["messages"], example["context_date"], session_id |
415 | 431 | )
|
416 | 432 |
|
417 | 433 | original_text = example["messages"][-1]
|
418 | 434 |
|
419 |
| - # Store and extract |
420 |
| - from agent_memory_server.vectorstore_factory import get_vectorstore_adapter |
421 |
| - |
422 |
| - adapter = await get_vectorstore_adapter() |
423 |
| - await adapter.add_memories([memory]) |
424 |
| - await extract_discrete_memories([memory]) |
| 435 | + # Use thread-aware extraction (the whole point of our implementation!) |
| 436 | + from agent_memory_server.long_term_memory import ( |
| 437 | + extract_memories_from_session_thread, |
| 438 | + ) |
425 | 439 |
|
426 |
| - # Retrieve all extracted discrete memories to get the grounded text |
427 |
| - all_memories = await adapter.search_memories(query="", limit=50) |
428 |
| - discrete_memories = [ |
429 |
| - m |
430 |
| - for m in all_memories.memories |
431 |
| - if m.memory_type in ["episodic", "semantic"] |
432 |
| - and m.session_id == memory.session_id |
433 |
| - ] |
| 440 | + extracted_memories = await extract_memories_from_session_thread( |
| 441 | + session_id=session_id, |
| 442 | + namespace="test-namespace", |
| 443 | + user_id="test-integration-user", |
| 444 | + ) |
434 | 445 |
|
435 | 446 | # Combine the grounded memories into a single text for evaluation
|
436 | 447 | grounded_text = (
|
437 |
| - " ".join([dm.text for dm in discrete_memories]) |
438 |
| - if discrete_memories |
| 448 | + " ".join([mem.text for mem in extracted_memories]) |
| 449 | + if extracted_memories |
439 | 450 | else original_text
|
440 | 451 | )
|
441 | 452 |
|
|
0 commit comments